1//===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Transforms/DialectConversion.h"
10#include "mlir/Config/mlir-config.h"
11#include "mlir/IR/Block.h"
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/BuiltinOps.h"
14#include "mlir/IR/Dominance.h"
15#include "mlir/IR/IRMapping.h"
16#include "mlir/IR/Iterators.h"
17#include "mlir/Interfaces/FunctionInterfaces.h"
18#include "mlir/Rewrite/PatternApplicator.h"
19#include "llvm/ADT/SmallPtrSet.h"
20#include "llvm/Support/Debug.h"
21#include "llvm/Support/FormatVariadic.h"
22#include "llvm/Support/SaveAndRestore.h"
23#include "llvm/Support/ScopedPrinter.h"
24#include <optional>
25
26using namespace mlir;
27using namespace mlir::detail;
28
29#define DEBUG_TYPE "dialect-conversion"
30
31/// A utility function to log a successful result for the given reason.
32template <typename... Args>
33static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
34 LLVM_DEBUG({
35 os.unindent();
36 os.startLine() << "} -> SUCCESS";
37 if (!fmt.empty())
38 os.getOStream() << " : "
39 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
40 os.getOStream() << "\n";
41 });
42}
43
44/// A utility function to log a failure result for the given reason.
45template <typename... Args>
46static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
47 LLVM_DEBUG({
48 os.unindent();
49 os.startLine() << "} -> FAILURE : "
50 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
51 << "\n";
52 });
53}
54
55/// Helper function that computes an insertion point where the given value is
56/// defined and can be used without a dominance violation.
57static OpBuilder::InsertPoint computeInsertPoint(Value value) {
58 Block *insertBlock = value.getParentBlock();
59 Block::iterator insertPt = insertBlock->begin();
60 if (OpResult inputRes = dyn_cast<OpResult>(Val&: value))
61 insertPt = ++inputRes.getOwner()->getIterator();
62 return OpBuilder::InsertPoint(insertBlock, insertPt);
63}
64
65/// Helper function that computes an insertion point where the given values are
66/// defined and can be used without a dominance violation.
67static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) {
68 assert(!vals.empty() && "expected at least one value");
69 DominanceInfo domInfo;
70 OpBuilder::InsertPoint pt = computeInsertPoint(value: vals.front());
71 for (Value v : vals.drop_front()) {
72 // Choose the "later" insertion point.
73 OpBuilder::InsertPoint nextPt = computeInsertPoint(value: v);
74 if (domInfo.dominates(aBlock: pt.getBlock(), aIt: pt.getPoint(), bBlock: nextPt.getBlock(),
75 bIt: nextPt.getPoint())) {
76 // pt is before nextPt => choose nextPt.
77 pt = nextPt;
78 } else {
79#ifndef NDEBUG
80 // nextPt should be before pt => choose pt.
81 // If pt, nextPt are no dominance relationship, then there is no valid
82 // insertion point at which all given values are defined.
83 bool dom = domInfo.dominates(nextPt.getBlock(), nextPt.getPoint(),
84 pt.getBlock(), pt.getPoint());
85 assert(dom && "unable to find valid insertion point");
86#endif // NDEBUG
87 }
88 }
89 return pt;
90}
91
92//===----------------------------------------------------------------------===//
93// ConversionValueMapping
94//===----------------------------------------------------------------------===//
95
/// A vector of SSA values, optimized for the most common case of a single
/// value (the inline capacity of one element avoids a heap allocation in that
/// case). Used both as key and value type of `ConversionValueMapping`.
using ValueVector = SmallVector<Value, 1>;
99
100namespace {
101
102/// Helper class to make it possible to use `ValueVector` as a key in DenseMap.
103struct ValueVectorMapInfo {
104 static ValueVector getEmptyKey() { return ValueVector{Value()}; }
105 static ValueVector getTombstoneKey() { return ValueVector{Value(), Value()}; }
106 static ::llvm::hash_code getHashValue(const ValueVector &val) {
107 return ::llvm::hash_combine_range(R: val);
108 }
109 static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) {
110 return LHS == RHS;
111 }
112};
113
/// This class wraps a `DenseMap` from value vectors to value vectors to provide
/// recursive lookup functionality, i.e. a lookup keeps traversing as long as
/// the mapped values also have a mapping. It additionally records every value
/// that has ever been used as a mapping target (see `isMappedTo`).
struct ConversionValueMapping {
  /// Return "true" if an SSA value is mapped to the given value. May return
  /// false positives: `mappedTo` is never shrunk, neither when a mapping is
  /// erased nor when it is overwritten.
  bool isMappedTo(Value value) const { return mappedTo.contains(value); }

  /// Lookup the most recently mapped values with the desired types in the
  /// mapping.
  ///
  /// Special cases:
  /// - If the desired type range is empty, simply return the most recently
  ///   mapped values.
  /// - If there is no mapping to the desired types, also return the most
  ///   recently mapped values.
  /// - If there is no mapping for the given values at all, return the given
  ///   value.
  ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;

  /// Lookup the given value within the map, or return an empty vector if the
  /// value is not mapped. If it is mapped, this follows the same behavior
  /// as `lookupOrDefault`, except that an empty vector is also returned when
  /// `desiredTypes` is non-empty and no mapped values with exactly those
  /// types were found.
  ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;

  /// Type trait: true iff `T` (after decay) is exactly `ValueVector`. Used to
  /// select between the `map` overloads below.
  template <typename T>
  struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};

  /// Map a value vector to the one provided. A previous mapping for the same
  /// key is overwritten.
  ///
  /// When `LLVM_DEBUG` output is enabled, the existing chain starting at
  /// `newVal` is walked first to assert that no cycle back to `oldVal` is
  /// created.
  template <typename OldVal, typename NewVal>
  std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
  map(OldVal &&oldVal, NewVal &&newVal) {
    LLVM_DEBUG({
      ValueVector next(newVal);
      while (true) {
        assert(next != oldVal && "inserting cyclic mapping");
        auto it = mapping.find(next);
        if (it == mapping.end())
          break;
        next = it->second;
      }
    });
    // Record the targets before `newVal` is (possibly) moved into the map.
    mappedTo.insert_range(newVal);

    mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
  }

  /// Map a value vector or single value to the one provided.
  ///
  /// Convenience overload: whichever argument is not already a `ValueVector`
  /// is wrapped into a single-element vector before forwarding to the
  /// overload above.
  template <typename OldVal, typename NewVal>
  std::enable_if_t<!IsValueVector<OldVal>::value ||
                   !IsValueVector<NewVal>::value>
  map(OldVal &&oldVal, NewVal &&newVal) {
    if constexpr (IsValueVector<OldVal>{}) {
      map(std::forward<OldVal>(oldVal), ValueVector{newVal});
    } else if constexpr (IsValueVector<NewVal>{}) {
      map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
    } else {
      map(ValueVector{oldVal}, ValueVector{newVal});
    }
  }

  /// Map a single value to a (possibly multi-element) vector of values.
  void map(Value oldVal, SmallVector<Value> &&newVal) {
    map(ValueVector{oldVal}, ValueVector(std::move(newVal)));
  }

  /// Drop the mapping whose key is exactly `value`. Note: `mappedTo` is not
  /// updated.
  void erase(const ValueVector &value) { mapping.erase(value); }

private:
  /// Current value mappings.
  DenseMap<ValueVector, ValueVector, ValueVectorMapInfo> mapping;

  /// All SSA values that are mapped to. May contain false positives.
  DenseSet<Value> mappedTo;
};
188} // namespace
189
/// Follow the mapping chain starting at `from`. In each step, every value of
/// the current vector is replaced by its own mapped values, if any (1:N
/// replacements). If no individual value was replaced, the whole vector is
/// looked up instead (materializations). The walk ends when neither lookup
/// finds anything. The last (deepest) vector whose types equal `desiredTypes`
/// is returned; if there is none, the leaf vector is returned.
ValueVector
ConversionValueMapping::lookupOrDefault(Value from,
                                        TypeRange desiredTypes) const {
  // Try to find the deepest values that have the desired types. If there is no
  // such mapping, simply return the deepest values.
  ValueVector desiredValue;
  ValueVector current{from};
  do {
    // Store the current value if the types match. Later (deeper) matches
    // overwrite earlier ones.
    if (TypeRange(ValueRange(current)) == desiredTypes)
      desiredValue = current;

    // If possible, replace each value with (one or multiple) mapped values.
    ValueVector next;
    for (Value v : current) {
      auto it = mapping.find({v});
      if (it != mapping.end()) {
        llvm::append_range(next, it->second);
      } else {
        next.push_back(v);
      }
    }
    if (next != current) {
      // If at least one value was replaced, continue the lookup from there.
      current = std::move(next);
      continue;
    }

    // Otherwise: Check if there is a mapping for the entire vector. Such
    // mappings are materializations. (N:M mapping are not supported for value
    // replacements.)
    //
    // Note: From a correctness point of view, materializations do not have to
    // be stored (and looked up) in the mapping. But for performance reasons,
    // we choose to reuse existing IR (when possible) instead of creating it
    // multiple times.
    auto it = mapping.find(current);
    if (it == mapping.end()) {
      // No mapping found: The lookup stops here.
      break;
    }
    current = it->second;
  } while (true);

  // If the desired values were found use them, otherwise default to the leaf
  // values.
  // Note: If `desiredTypes` is empty, this function always returns `current`.
  return !desiredValue.empty() ? std::move(desiredValue) : std::move(current);
}
239
240ValueVector ConversionValueMapping::lookupOrNull(Value from,
241 TypeRange desiredTypes) const {
242 ValueVector result = lookupOrDefault(from, desiredTypes);
243 if (result == ValueVector{from} ||
244 (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes))
245 return {};
246 return result;
247}
248
249//===----------------------------------------------------------------------===//
250// Rewriter and Translation State
251//===----------------------------------------------------------------------===//
252namespace {
/// A snapshot of the conversion rewriter state, consisting solely of counters.
/// Capturing one before applying a set of rewrites makes it possible to undo
/// those rewrites later by going back to this snapshot.
struct RewriterState {
  RewriterState(unsigned rewrites, unsigned ignoredOperations,
                unsigned replacedOps)
      : numRewrites{rewrites}, numIgnoredOperations{ignoredOperations},
        numReplacedOps{replacedOps} {}

  /// Number of rewrites performed when the snapshot was taken.
  unsigned numRewrites;

  /// Number of ignored operations when the snapshot was taken.
  unsigned numIgnoredOperations;

  /// Number of replaced ops scheduled for erasure when the snapshot was taken.
  unsigned numReplacedOps;
};
270
271//===----------------------------------------------------------------------===//
272// IR rewrites
273//===----------------------------------------------------------------------===//
274
275static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
276
277/// Notify the listener that the given block and its contents are being erased.
278static void notifyIRErased(RewriterBase::Listener *listener, Block &b) {
279 for (Operation &op : b)
280 notifyIRErased(listener, op);
281 listener->notifyBlockErased(block: &b);
282}
283
284/// Notify the listener that the given operation and its contents are being
285/// erased.
286static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
287 for (Region &r : op.getRegions()) {
288 for (Block &b : r) {
289 notifyIRErased(listener, b);
290 }
291 }
292 listener->notifyOperationErased(op: &op);
293}
294
/// An IR rewrite that can be committed (upon success) or rolled back (upon
/// failure).
///
/// The dialect conversion keeps track of IR modifications (requested by the
/// user through the rewriter API) in `IRRewrite` objects. Some kinds of
/// rewrites are directly applied to the IR as the rewriter API is used, some
/// are applied partially, and some are delayed until the `IRRewrite` objects
/// are committed.
class IRRewrite {
public:
  /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
  /// Enum values are ordered, so that they can be used in `classof`: first all
  /// block rewrites, then all operation rewrites. `BlockRewrite::classof` and
  /// `OperationRewrite::classof` check the ranges [CreateBlock,
  /// ReplaceBlockArg] and [MoveOperation, UnresolvedMaterialization], so new
  /// kinds must be added inside the matching range.
  enum class Kind {
    // Block rewrites
    CreateBlock,
    EraseBlock,
    InlineBlock,
    MoveBlock,
    BlockTypeConversion,
    ReplaceBlockArg,
    // Operation rewrites
    MoveOperation,
    ModifyOperation,
    ReplaceOperation,
    CreateOperation,
    UnresolvedMaterialization
  };

  virtual ~IRRewrite() = default;

  /// Roll back the rewrite. Operations may be erased during rollback.
  virtual void rollback() = 0;

  /// Commit the rewrite. At this point, it is certain that the dialect
  /// conversion will succeed. All IR modifications, except for operation/block
  /// erasure, must be performed through the given rewriter.
  ///
  /// Instead of erasing operations/blocks, they should merely be unlinked
  /// during the commit phase and finally be erased during the cleanup phase.
  /// This is because internal dialect conversion state (such as `mapping`) may
  /// still be using them.
  ///
  /// Any IR modification that was already performed before the commit phase
  /// (e.g., insertion of an op) must be communicated to the listener that may
  /// be attached to the given rewriter.
  virtual void commit(RewriterBase &rewriter) {}

  /// Cleanup operations/blocks. Cleanup is called after commit.
  virtual void cleanup(RewriterBase &rewriter) {}

  /// Return the kind of this rewrite (used for LLVM-style RTTI).
  Kind getKind() const { return kind; }

  /// Every rewrite is an `IRRewrite`.
  static bool classof(const IRRewrite *rewrite) { return true; }

protected:
  IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
      : kind(kind), rewriterImpl(rewriterImpl) {}

  /// Return the conversion config (defined out of line).
  const ConversionConfig &getConfig() const;

  /// The kind of this rewrite; fixed at construction.
  const Kind kind;
  /// The rewriter implementation that created this rewrite (see
  /// `ConversionPatternRewriterImpl::appendRewrite`).
  ConversionPatternRewriterImpl &rewriterImpl;
};
358
/// A block rewrite: base class of all rewrites that operate on a single block.
/// `rollback` remains pure virtual, so this class is abstract.
class BlockRewrite : public IRRewrite {
public:
  /// Return the block that this rewrite operates on.
  Block *getBlock() const { return block; }

  /// Block rewrite kinds form the contiguous range
  /// [CreateBlock, ReplaceBlockArg] of `IRRewrite::Kind`.
  static bool classof(const IRRewrite *rewrite) {
    return rewrite->getKind() >= Kind::CreateBlock &&
           rewrite->getKind() <= Kind::ReplaceBlockArg;
  }

protected:
  BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
               Block *block)
      : IRRewrite(kind, rewriterImpl), block(block) {}

  // The block that this rewrite operates on.
  Block *block;
};
378
/// Creation of a block. Block creations are immediately reflected in the IR.
/// Committing only notifies the listener (if any). During rollback, the newly
/// created block is erased.
class CreateBlockRewrite : public BlockRewrite {
public:
  CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
      : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}

  static bool classof(const IRRewrite *rewrite) {
    return rewrite->getKind() == Kind::CreateBlock;
  }

  void commit(RewriterBase &rewriter) override {
    // The block was already created and inserted. Just inform the listener.
    // No previous region/position is passed: the block did not exist before.
    if (auto *listener = rewriter.getListener())
      listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
  }

  void rollback() override {
    // Unlink all of the operations within this block, they will be deleted
    // separately (presumably by the rollback of their own rewrites).
    auto &blockOps = block->getOperations();
    while (!blockOps.empty())
      blockOps.remove(blockOps.begin());
    // Drop all uses of the block itself before destroying it.
    block->dropAllUses();
    // A block that is still linked into a region is erased from there; a
    // detached block is freed directly.
    if (block->getParent())
      block->erase();
    else
      delete block;
  }
};
410
/// Erasure of a block. Block erasures are partially reflected in the IR. Erased
/// blocks are immediately unlinked, but only erased during cleanup. This makes
/// it easier to rollback a block erasure: the block is simply inserted into its
/// original location.
///
/// Note: the original location (parent region and next block) is captured in
/// the constructor, so this rewrite must be created while the block is still
/// linked into its region.
class EraseBlockRewrite : public BlockRewrite {
public:
  EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
      : BlockRewrite(Kind::EraseBlock, rewriterImpl, block),
        region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}

  static bool classof(const IRRewrite *rewrite) {
    return rewrite->getKind() == Kind::EraseBlock;
  }

  // Both `rollback` and `cleanup` reset `block`; reaching the destructor with
  // a non-null `block` means neither of them ran.
  ~EraseBlockRewrite() override {
    assert(!block &&
           "rewrite was neither rolled back nor committed/cleaned up");
  }

  void rollback() override {
    // The block (owned by this rewrite) was not actually erased yet. It was
    // just unlinked. Put it back into its original position.
    assert(block && "expected block");
    auto &blockList = region->getBlocks();
    Region::iterator before = insertBeforeBlock
                                  ? Region::iterator(insertBeforeBlock)
                                  : blockList.end();
    blockList.insert(before, block);
    // Ownership is back with the region.
    block = nullptr;
  }

  void commit(RewriterBase &rewriter) override {
    assert(block && "expected block");

    // Notify the listener that the block and its contents are being erased.
    if (auto *listener =
            dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
      notifyIRErased(listener, *block);
  }

  void cleanup(RewriterBase &rewriter) override {
    // Erase the contents of the block, back to front, so that later ops
    // (typically the users) are erased before the ops that they use.
    for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
      rewriter.eraseOp(&op);
    assert(block->empty() && "expected empty block");

    // Erase the block.
    block->dropAllDefinedValueUses();
    delete block;
    block = nullptr;
  }

private:
  // The region in which this block was previously contained.
  Region *region;

  // The original successor of this block before it was unlinked. "nullptr" if
  // this block was the last block in the region.
  Block *insertBeforeBlock;
};
471
/// Inlining of a block. This rewrite is immediately reflected in the IR.
/// Note: This rewrite represents only the inlining of the operations. The
/// erasure of the inlined block is a separate rewrite.
///
/// `block` (inherited) is the destination block and `sourceBlock` is the block
/// whose operations were inlined. The first/last operations are captured in
/// the constructor, so this rewrite must be created before the operations are
/// moved.
class InlineBlockRewrite : public BlockRewrite {
public:
  // Note: the `before` parameter is currently unused.
  InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
                     Block *sourceBlock, Block::iterator before)
      : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
        sourceBlock(sourceBlock),
        firstInlinedInst(sourceBlock->empty() ? nullptr
                                              : &sourceBlock->front()),
        lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
    // If a listener is attached to the dialect conversion, ops must be moved
    // one-by-one. When they are moved in bulk, notifications cannot be sent
    // because the ops that used to be in the source block at the time of the
    // inlining (before the "commit" phase) are unknown at the time when
    // notifications are sent (which is during the "commit" phase).
    assert(!getConfig().listener &&
           "InlineBlockRewrite not supported if listener is attached");
  }

  static bool classof(const IRRewrite *rewrite) {
    return rewrite->getKind() == Kind::InlineBlock;
  }

  void rollback() override {
    // Move the inlined operations from the destination block back to the
    // front of the source block. Assumes they are still a contiguous range
    // [firstInlinedInst, lastInlinedInst] in the destination block.
    if (firstInlinedInst) {
      assert(lastInlinedInst && "expected operation");
      sourceBlock->getOperations().splice(sourceBlock->begin(),
                                          block->getOperations(),
                                          Block::iterator(firstInlinedInst),
                                          ++Block::iterator(lastInlinedInst));
    }
  }

private:
  // The block that originally contained the operations.
  Block *sourceBlock;

  // The first inlined operation ("nullptr" if the source block was empty).
  Operation *firstInlinedInst;

  // The last inlined operation ("nullptr" if the source block was empty).
  Operation *lastInlinedInst;
};
519
/// Moving of a block. This rewrite is immediately reflected in the IR. The
/// rewrite stores the block's previous location so that rollback can move it
/// back.
class MoveBlockRewrite : public BlockRewrite {
public:
  MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
                   Region *region, Block *insertBeforeBlock)
      : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region),
        insertBeforeBlock(insertBeforeBlock) {}

  static bool classof(const IRRewrite *rewrite) {
    return rewrite->getKind() == Kind::MoveBlock;
  }

  void commit(RewriterBase &rewriter) override {
    // The block was already moved. Just inform the listener.
    if (auto *listener = rewriter.getListener()) {
      // Note: `previousIt` cannot be passed because this is a delayed
      // notification and iterators into past IR state cannot be represented.
      listener->notifyBlockInserted(block, /*previous=*/region,
                                    /*previousIt=*/{});
    }
  }

  void rollback() override {
    // Move the block back to its original position.
    Region::iterator before =
        insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
    region->getBlocks().splice(before, block->getParent()->getBlocks(), block);
  }

private:
  // The region in which this block was previously contained.
  Region *region;

  // The original successor of this block before it was moved. "nullptr" if
  // this block was the last block in the region.
  Block *insertBeforeBlock;
};
557
/// Block type conversion. This rewrite is partially reflected in the IR.
/// The inherited `block` is the original block; `newBlock` is the block that
/// was created by the signature conversion. `commit` and `rollback` are
/// defined out of line.
class BlockTypeConversionRewrite : public BlockRewrite {
public:
  BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
                             Block *origBlock, Block *newBlock)
      : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock),
        newBlock(newBlock) {}

  static bool classof(const IRRewrite *rewrite) {
    return rewrite->getKind() == Kind::BlockTypeConversion;
  }

  /// Return the block before the signature conversion.
  Block *getOrigBlock() const { return block; }

  /// Return the block created by the signature conversion.
  Block *getNewBlock() const { return newBlock; }

  void commit(RewriterBase &rewriter) override;

  void rollback() override;

private:
  /// The new block that was created as part of this signature conversion.
  Block *newBlock;
};
582
/// Replacing a block argument. This rewrite is not immediately reflected in the
/// IR. An internal IR mapping is updated, but the actual replacement is delayed
/// until the rewrite is committed. `commit` and `rollback` are defined out of
/// line.
class ReplaceBlockArgRewrite : public BlockRewrite {
public:
  ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
                         Block *block, BlockArgument arg,
                         const TypeConverter *converter)
      : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
        converter(converter) {}

  static bool classof(const IRRewrite *rewrite) {
    return rewrite->getKind() == Kind::ReplaceBlockArg;
  }

  void commit(RewriterBase &rewriter) override;

  void rollback() override;

private:
  /// The block argument that is being replaced.
  BlockArgument arg;

  /// The current type converter when the block argument was replaced.
  const TypeConverter *converter;
};
608
/// An operation rewrite: base class of all rewrites that operate on a single
/// operation. `rollback` remains pure virtual, so this class is abstract.
class OperationRewrite : public IRRewrite {
public:
  /// Return the operation that this rewrite operates on.
  Operation *getOperation() const { return op; }

  /// Operation rewrite kinds form the contiguous range
  /// [MoveOperation, UnresolvedMaterialization] of `IRRewrite::Kind`.
  static bool classof(const IRRewrite *rewrite) {
    return rewrite->getKind() >= Kind::MoveOperation &&
           rewrite->getKind() <= Kind::UnresolvedMaterialization;
  }

protected:
  OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
                   Operation *op)
      : IRRewrite(kind, rewriterImpl), op(op) {}

  // The operation that this rewrite operates on.
  Operation *op;
};
628
/// Moving of an operation. This rewrite is immediately reflected in the IR.
/// The rewrite stores the operation's previous location so that rollback can
/// move it back.
class MoveOperationRewrite : public OperationRewrite {
public:
  MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
                       Operation *op, Block *block, Operation *insertBeforeOp)
      : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block),
        insertBeforeOp(insertBeforeOp) {}

  static bool classof(const IRRewrite *rewrite) {
    return rewrite->getKind() == Kind::MoveOperation;
  }

  void commit(RewriterBase &rewriter) override {
    // The operation was already moved. Just inform the listener.
    if (auto *listener = rewriter.getListener()) {
      // Note: `previousIt` cannot be passed because this is a delayed
      // notification and iterators into past IR state cannot be represented.
      listener->notifyOperationInserted(
          op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
                                                  /*insertPt=*/{}));
    }
  }

  void rollback() override {
    // Move the operation back to its original position.
    Block::iterator before =
        insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
    block->getOperations().splice(before, op->getBlock()->getOperations(), op);
  }

private:
  // The block in which this operation was previously contained.
  Block *block;

  // The original successor of this operation before it was moved. "nullptr"
  // if this operation was the last operation in the block.
  Operation *insertBeforeOp;
};
667
668/// In-place modification of an op. This rewrite is immediately reflected in
669/// the IR. The previous state of the operation is stored in this object.
670class ModifyOperationRewrite : public OperationRewrite {
671public:
672 ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
673 Operation *op)
674 : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
675 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
676 operands(op->operand_begin(), op->operand_end()),
677 successors(op->successor_begin(), op->successor_end()) {
678 if (OpaqueProperties prop = op->getPropertiesStorage()) {
679 // Make a copy of the properties.
680 propertiesStorage = operator new(op->getPropertiesStorageSize());
681 OpaqueProperties propCopy(propertiesStorage);
682 name.initOpProperties(storage: propCopy, /*init=*/prop);
683 }
684 }
685
686 static bool classof(const IRRewrite *rewrite) {
687 return rewrite->getKind() == Kind::ModifyOperation;
688 }
689
690 ~ModifyOperationRewrite() override {
691 assert(!propertiesStorage &&
692 "rewrite was neither committed nor rolled back");
693 }
694
695 void commit(RewriterBase &rewriter) override {
696 // Notify the listener that the operation was modified in-place.
697 if (auto *listener =
698 dyn_cast_or_null<RewriterBase::Listener>(Val: rewriter.getListener()))
699 listener->notifyOperationModified(op);
700
701 if (propertiesStorage) {
702 OpaqueProperties propCopy(propertiesStorage);
703 // Note: The operation may have been erased in the mean time, so
704 // OperationName must be stored in this object.
705 name.destroyOpProperties(properties: propCopy);
706 operator delete(propertiesStorage);
707 propertiesStorage = nullptr;
708 }
709 }
710
711 void rollback() override {
712 op->setLoc(loc);
713 op->setAttrs(attrs);
714 op->setOperands(operands);
715 for (const auto &it : llvm::enumerate(First&: successors))
716 op->setSuccessor(block: it.value(), index: it.index());
717 if (propertiesStorage) {
718 OpaqueProperties propCopy(propertiesStorage);
719 op->copyProperties(rhs: propCopy);
720 name.destroyOpProperties(properties: propCopy);
721 operator delete(propertiesStorage);
722 propertiesStorage = nullptr;
723 }
724 }
725
726private:
727 OperationName name;
728 LocationAttr loc;
729 DictionaryAttr attrs;
730 SmallVector<Value, 8> operands;
731 SmallVector<Block *, 2> successors;
732 void *propertiesStorage = nullptr;
733};
734
/// Replacing an operation. Erasing an operation is treated as a special case
/// with "null" replacements. This rewrite is not immediately reflected in the
/// IR. An internal IR mapping is updated, but values are not replaced and the
/// original op is not erased until the rewrite is committed. `commit`,
/// `rollback` and `cleanup` are defined out of line.
class ReplaceOperationRewrite : public OperationRewrite {
public:
  ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
                          Operation *op, const TypeConverter *converter)
      : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
        converter(converter) {}

  static bool classof(const IRRewrite *rewrite) {
    return rewrite->getKind() == Kind::ReplaceOperation;
  }

  void commit(RewriterBase &rewriter) override;

  void rollback() override;

  void cleanup(RewriterBase &rewriter) override;

private:
  /// An optional type converter that can be used to materialize conversions
  /// between the new and old values if necessary.
  const TypeConverter *converter;
};
761
/// Creation of an operation. Operation creations are immediately reflected in
/// the IR; committing only notifies the listener (if any). `rollback` is
/// defined out of line.
class CreateOperationRewrite : public OperationRewrite {
public:
  CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
                         Operation *op)
      : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}

  static bool classof(const IRRewrite *rewrite) {
    return rewrite->getKind() == Kind::CreateOperation;
  }

  void commit(RewriterBase &rewriter) override {
    // The operation was already created and inserted. Just inform the listener.
    if (auto *listener = rewriter.getListener())
      listener->notifyOperationInserted(op, /*previous=*/{});
  }

  void rollback() override;
};
780
/// The type of materialization. (Stored in the integer bits of a
/// `PointerIntPair` by `UnresolvedMaterializationInfo`, so it must stay small.)
enum MaterializationKind {
  /// This materialization materializes a conversion from an illegal type to a
  /// legal one.
  Target,

  /// This materialization materializes a conversion from a legal type back to
  /// an illegal one.
  Source
};
791
/// Helper class that stores metadata about an unresolved materialization.
/// A default-constructed instance has a null converter, the kind with value 0
/// (`Target`) and a null original type.
class UnresolvedMaterializationInfo {
public:
  UnresolvedMaterializationInfo() = default;
  UnresolvedMaterializationInfo(const TypeConverter *converter,
                                MaterializationKind kind, Type originalType)
      : converterAndKind(converter, kind), originalType(originalType) {}

  /// Return the type converter of this materialization (which may be null).
  const TypeConverter *getConverter() const {
    return converterAndKind.getPointer();
  }

  /// Return the kind of this materialization.
  MaterializationKind getMaterializationKind() const {
    return converterAndKind.getInt();
  }

  /// Return the original type of the SSA value.
  Type getOriginalType() const { return originalType; }

private:
  /// The corresponding type converter to use when resolving this
  /// materialization, and the kind of this materialization (packed into the
  /// low bits of the pointer).
  llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
      converterAndKind;

  /// The original type of the SSA value. Only used for target
  /// materializations.
  Type originalType;
};
823
/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
/// op. Unresolved materializations fold away or are replaced with
/// source/target materializations at the end of the dialect conversion.
/// `rollback` is defined out of line.
class UnresolvedMaterializationRewrite : public OperationRewrite {
public:
  UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
                                   UnrealizedConversionCastOp op,
                                   ValueVector mappedValues)
      : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
        mappedValues(std::move(mappedValues)) {}

  static bool classof(const IRRewrite *rewrite) {
    return rewrite->getKind() == Kind::UnresolvedMaterialization;
  }

  void rollback() override;

  /// Return the cast op. This hides `OperationRewrite::getOperation` with a
  /// typed accessor.
  UnrealizedConversionCastOp getOperation() const {
    return cast<UnrealizedConversionCastOp>(op);
  }

private:
  /// The values in the conversion value mapping that are being replaced by the
  /// results of this unresolved materialization.
  ValueVector mappedValues;
};
850} // namespace
851
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// Return "true" if `rewrites` contains a rewrite of type `RewriteTy` that
/// operates on `op`. Only compiled when expensive pattern API checks are
/// enabled.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Operation *op) {
  auto isMatch = [&](auto &rewrite) {
    if (auto *typed = dyn_cast<RewriteTy>(rewrite.get()))
      return typed->getOperation() == op;
    return false;
  };
  return any_of(std::forward<R>(rewrites), isMatch);
}

/// Return "true" if `rewrites` contains a rewrite of type `RewriteTy` that
/// operates on `block`. Only compiled when expensive pattern API checks are
/// enabled.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Block *block) {
  auto isMatch = [&](auto &rewrite) {
    if (auto *typed = dyn_cast<RewriteTy>(rewrite.get()))
      return typed->getBlock() == block;
    return false;
  };
  return any_of(std::forward<R>(rewrites), isMatch);
}
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
873
874//===----------------------------------------------------------------------===//
875// ConversionPatternRewriterImpl
876//===----------------------------------------------------------------------===//
877namespace mlir {
878namespace detail {
879struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
880 explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
881 const ConversionConfig &config)
882 : context(ctx), config(config) {}
883
884 //===--------------------------------------------------------------------===//
885 // State Management
886 //===--------------------------------------------------------------------===//
887
888 /// Return the current state of the rewriter.
889 RewriterState getCurrentState();
890
891 /// Apply all requested operation rewrites. This method is invoked when the
892 /// conversion process succeeds.
893 void applyRewrites();
894
895 /// Reset the state of the rewriter to a previously saved point. Optionally,
896 /// the name of the pattern that triggered the rollback can specified for
897 /// debugging purposes.
898 void resetState(RewriterState state, StringRef patternName = "");
899
900 /// Append a rewrite. Rewrites are committed upon success and rolled back upon
901 /// failure.
902 template <typename RewriteTy, typename... Args>
903 void appendRewrite(Args &&...args) {
904 rewrites.push_back(
905 std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
906 }
907
908 /// Undo the rewrites (motions, splits) one by one in reverse order until
909 /// "numRewritesToKeep" rewrites remains. Optionally, the name of the pattern
910 /// that triggered the rollback can specified for debugging purposes.
911 void undoRewrites(unsigned numRewritesToKeep = 0, StringRef patternName = "");
912
913 /// Remap the given values to those with potentially different types. Returns
914 /// success if the values could be remapped, failure otherwise. `valueDiagTag`
915 /// is the tag used when describing a value within a diagnostic, e.g.
916 /// "operand".
917 LogicalResult remapValues(StringRef valueDiagTag,
918 std::optional<Location> inputLoc,
919 PatternRewriter &rewriter, ValueRange values,
920 SmallVector<ValueVector> &remapped);
921
922 /// Return "true" if the given operation is ignored, and does not need to be
923 /// converted.
924 bool isOpIgnored(Operation *op) const;
925
926 /// Return "true" if the given operation was replaced or erased.
927 bool wasOpReplaced(Operation *op) const;
928
929 //===--------------------------------------------------------------------===//
930 // IR Rewrites / Type Conversion
931 //===--------------------------------------------------------------------===//
932
933 /// Convert the types of block arguments within the given region.
934 FailureOr<Block *>
935 convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
936 const TypeConverter &converter,
937 TypeConverter::SignatureConversion *entryConversion);
938
939 /// Apply the given signature conversion on the given block. The new block
940 /// containing the updated signature is returned. If no conversions were
941 /// necessary, e.g. if the block has no arguments, `block` is returned.
942 /// `converter` is used to generate any necessary cast operations that
943 /// translate between the origin argument types and those specified in the
944 /// signature conversion.
945 Block *applySignatureConversion(
946 ConversionPatternRewriter &rewriter, Block *block,
947 const TypeConverter *converter,
948 TypeConverter::SignatureConversion &signatureConversion);
949
950 /// Replace the results of the given operation with the given values and
951 /// erase the operation.
952 ///
953 /// There can be multiple replacement values for each result (1:N
954 /// replacement). If the replacement values are empty, the respective result
955 /// is dropped and a source materialization is built if the result still has
956 /// uses.
957 void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
958
959 /// Replace the given block argument with the given values. The specified
960 /// converter is used to build materializations (if necessary).
961 void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to,
962 const TypeConverter *converter);
963
964 /// Erase the given block and its contents.
965 void eraseBlock(Block *block);
966
967 /// Inline the source block into the destination block before the given
968 /// iterator.
969 void inlineBlockBefore(Block *source, Block *dest, Block::iterator before);
970
971 //===--------------------------------------------------------------------===//
972 // Materializations
973 //===--------------------------------------------------------------------===//
974
975 /// Build an unresolved materialization operation given a range of output
976 /// types and a list of input operands. Returns the inputs if they their
977 /// types match the output types.
978 ///
979 /// If a cast op was built, it can optionally be returned with the `castOp`
980 /// output argument.
981 ///
982 /// If `valuesToMap` is set to a non-null Value, then that value is mapped to
983 /// the results of the unresolved materialization in the conversion value
984 /// mapping.
985 ValueRange buildUnresolvedMaterialization(
986 MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
987 ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
988 Type originalType, const TypeConverter *converter,
989 UnrealizedConversionCastOp *castOp = nullptr);
990
991 /// Find a replacement value for the given SSA value in the conversion value
992 /// mapping. The replacement value must have the same type as the given SSA
993 /// value. If there is no replacement value with the correct type, find the
994 /// latest replacement value (regardless of the type) and build a source
995 /// materialization.
996 Value findOrBuildReplacementValue(Value value,
997 const TypeConverter *converter);
998
999 //===--------------------------------------------------------------------===//
1000 // Rewriter Notification Hooks
1001 //===--------------------------------------------------------------------===//
1002
1003 //// Notifies that an op was inserted.
1004 void notifyOperationInserted(Operation *op,
1005 OpBuilder::InsertPoint previous) override;
1006
1007 /// Notifies that a block was inserted.
1008 void notifyBlockInserted(Block *block, Region *previous,
1009 Region::iterator previousIt) override;
1010
1011 /// Notifies that a pattern match failed for the given reason.
1012 void
1013 notifyMatchFailure(Location loc,
1014 function_ref<void(Diagnostic &)> reasonCallback) override;
1015
1016 //===--------------------------------------------------------------------===//
1017 // IR Erasure
1018 //===--------------------------------------------------------------------===//
1019
1020 /// A rewriter that keeps track of erased ops and blocks. It ensures that no
1021 /// operation or block is erased multiple times. This rewriter assumes that
1022 /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
1023 struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener {
1024 public:
1025 SingleEraseRewriter(
1026 MLIRContext *context,
1027 std::function<void(Operation *)> opErasedCallback = nullptr)
1028 : RewriterBase(context, /*listener=*/this),
1029 opErasedCallback(opErasedCallback) {}
1030
1031 /// Erase the given op (unless it was already erased).
1032 void eraseOp(Operation *op) override {
1033 if (wasErased(ptr: op))
1034 return;
1035 op->dropAllUses();
1036 RewriterBase::eraseOp(op);
1037 }
1038
1039 /// Erase the given block (unless it was already erased).
1040 void eraseBlock(Block *block) override {
1041 if (wasErased(ptr: block))
1042 return;
1043 assert(block->empty() && "expected empty block");
1044 block->dropAllDefinedValueUses();
1045 RewriterBase::eraseBlock(block);
1046 }
1047
1048 bool wasErased(void *ptr) const { return erased.contains(V: ptr); }
1049
1050 void notifyOperationErased(Operation *op) override {
1051 erased.insert(V: op);
1052 if (opErasedCallback)
1053 opErasedCallback(op);
1054 }
1055
1056 void notifyBlockErased(Block *block) override { erased.insert(V: block); }
1057
1058 private:
1059 /// Pointers to all erased operations and blocks.
1060 DenseSet<void *> erased;
1061
1062 /// A callback that is invoked when an operation is erased.
1063 std::function<void(Operation *)> opErasedCallback;
1064 };
1065
1066 //===--------------------------------------------------------------------===//
1067 // State
1068 //===--------------------------------------------------------------------===//
1069
1070 /// MLIR context.
1071 MLIRContext *context;
1072
1073 // Mapping between replaced values that differ in type. This happens when
1074 // replacing a value with one of a different type.
1075 ConversionValueMapping mapping;
1076
1077 /// Ordered list of block operations (creations, splits, motions).
1078 SmallVector<std::unique_ptr<IRRewrite>> rewrites;
1079
1080 /// A set of operations that should no longer be considered for legalization.
1081 /// E.g., ops that are recursively legal. Ops that were replaced/erased are
1082 /// tracked separately.
1083 SetVector<Operation *> ignoredOps;
1084
1085 /// A set of operations that were replaced/erased. Such ops are not erased
1086 /// immediately but only when the dialect conversion succeeds. In the mean
1087 /// time, they should no longer be considered for legalization and any attempt
1088 /// to modify/access them is invalid rewriter API usage.
1089 SetVector<Operation *> replacedOps;
1090
1091 /// A set of operations that were created by the current pattern.
1092 SetVector<Operation *> patternNewOps;
1093
1094 /// A set of operations that were modified by the current pattern.
1095 SetVector<Operation *> patternModifiedOps;
1096
1097 /// A set of blocks that were inserted (newly-created blocks or moved blocks)
1098 /// by the current pattern.
1099 SetVector<Block *> patternInsertedBlocks;
1100
1101 /// A mapping for looking up metadata of unresolved materializations.
1102 DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
1103 unresolvedMaterializations;
1104
1105 /// The current type converter, or nullptr if no type converter is currently
1106 /// active.
1107 const TypeConverter *currentTypeConverter = nullptr;
1108
1109 /// A mapping of regions to type converters that should be used when
1110 /// converting the arguments of blocks within that region.
1111 DenseMap<Region *, const TypeConverter *> regionToConverter;
1112
1113 /// Dialect conversion configuration.
1114 const ConversionConfig &config;
1115
1116#ifndef NDEBUG
1117 /// A set of operations that have pending updates. This tracking isn't
1118 /// strictly necessary, and is thus only active during debug builds for extra
1119 /// verification.
1120 SmallPtrSet<Operation *, 1> pendingRootUpdates;
1121
1122 /// A logger used to emit diagnostics during the conversion process.
1123 llvm::ScopedPrinter logger{llvm::dbgs()};
1124#endif
1125};
1126} // namespace detail
1127} // namespace mlir
1128
1129const ConversionConfig &IRRewrite::getConfig() const {
1130 return rewriterImpl.config;
1131}
1132
1133void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1134 // Inform the listener about all IR modifications that have already taken
1135 // place: References to the original block have been replaced with the new
1136 // block.
1137 if (auto *listener =
1138 dyn_cast_or_null<RewriterBase::Listener>(Val: rewriter.getListener()))
1139 for (Operation *op : getNewBlock()->getUsers())
1140 listener->notifyOperationModified(op);
1141}
1142
1143void BlockTypeConversionRewrite::rollback() {
1144 getNewBlock()->replaceAllUsesWith(newValue: getOrigBlock());
1145}
1146
1147void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
1148 Value repl = rewriterImpl.findOrBuildReplacementValue(value: arg, converter);
1149 if (!repl)
1150 return;
1151
1152 if (isa<BlockArgument>(Val: repl)) {
1153 rewriter.replaceAllUsesWith(from: arg, to: repl);
1154 return;
1155 }
1156
1157 // If the replacement value is an operation, we check to make sure that we
1158 // don't replace uses that are within the parent operation of the
1159 // replacement value.
1160 Operation *replOp = cast<OpResult>(Val&: repl).getOwner();
1161 Block *replBlock = replOp->getBlock();
1162 rewriter.replaceUsesWithIf(from: arg, to: repl, functor: [&](OpOperand &operand) {
1163 Operation *user = operand.getOwner();
1164 return user->getBlock() != replBlock || replOp->isBeforeInBlock(other: user);
1165 });
1166}
1167
1168void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(value: {arg}); }
1169
1170void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1171 auto *listener =
1172 dyn_cast_or_null<RewriterBase::Listener>(Val: rewriter.getListener());
1173
1174 // Compute replacement values.
1175 SmallVector<Value> replacements =
1176 llvm::map_to_vector(C: op->getResults(), F: [&](OpResult result) {
1177 return rewriterImpl.findOrBuildReplacementValue(value: result, converter);
1178 });
1179
1180 // Notify the listener that the operation is about to be replaced.
1181 if (listener)
1182 listener->notifyOperationReplaced(op, replacement: replacements);
1183
1184 // Replace all uses with the new values.
1185 for (auto [result, newValue] :
1186 llvm::zip_equal(t: op->getResults(), u&: replacements))
1187 if (newValue)
1188 rewriter.replaceAllUsesWith(from: result, to: newValue);
1189
1190 // The original op will be erased, so remove it from the set of unlegalized
1191 // ops.
1192 if (getConfig().unlegalizedOps)
1193 getConfig().unlegalizedOps->erase(V: op);
1194
1195 // Notify the listener that the operation and its contents are being erased.
1196 if (listener)
1197 notifyIRErased(listener, op&: *op);
1198
1199 // Do not erase the operation yet. It may still be referenced in `mapping`.
1200 // Just unlink it for now and erase it during cleanup.
1201 op->getBlock()->getOperations().remove(IT: op);
1202}
1203
1204void ReplaceOperationRewrite::rollback() {
1205 for (auto result : op->getResults())
1206 rewriterImpl.mapping.erase(value: {result});
1207}
1208
1209void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1210 rewriter.eraseOp(op);
1211}
1212
1213void CreateOperationRewrite::rollback() {
1214 for (Region &region : op->getRegions()) {
1215 while (!region.getBlocks().empty())
1216 region.getBlocks().remove(IT: region.getBlocks().begin());
1217 }
1218 op->dropAllUses();
1219 op->erase();
1220}
1221
1222void UnresolvedMaterializationRewrite::rollback() {
1223 if (!mappedValues.empty())
1224 rewriterImpl.mapping.erase(value: mappedValues);
1225 rewriterImpl.unresolvedMaterializations.erase(Val: getOperation());
1226 op->erase();
1227}
1228
1229void ConversionPatternRewriterImpl::applyRewrites() {
1230 // Commit all rewrites.
1231 IRRewriter rewriter(context, config.listener);
1232 // Note: New rewrites may be added during the "commit" phase and the
1233 // `rewrites` vector may reallocate.
1234 for (size_t i = 0; i < rewrites.size(); ++i)
1235 rewrites[i]->commit(rewriter);
1236
1237 // Clean up all rewrites.
1238 SingleEraseRewriter eraseRewriter(
1239 context, /*opErasedCallback=*/[&](Operation *op) {
1240 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(Val: op))
1241 unresolvedMaterializations.erase(Val: castOp);
1242 });
1243 for (auto &rewrite : rewrites)
1244 rewrite->cleanup(rewriter&: eraseRewriter);
1245}
1246
1247//===----------------------------------------------------------------------===//
1248// State Management
1249//===----------------------------------------------------------------------===//
1250
1251RewriterState ConversionPatternRewriterImpl::getCurrentState() {
1252 return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
1253}
1254
1255void ConversionPatternRewriterImpl::resetState(RewriterState state,
1256 StringRef patternName) {
1257 // Undo any rewrites.
1258 undoRewrites(numRewritesToKeep: state.numRewrites, patternName);
1259
1260 // Pop all of the recorded ignored operations that are no longer valid.
1261 while (ignoredOps.size() != state.numIgnoredOperations)
1262 ignoredOps.pop_back();
1263
1264 while (replacedOps.size() != state.numReplacedOps)
1265 replacedOps.pop_back();
1266}
1267
1268void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
1269 StringRef patternName) {
1270 for (auto &rewrite :
1271 llvm::reverse(C: llvm::drop_begin(RangeOrContainer&: rewrites, N: numRewritesToKeep))) {
1272 if (!config.allowPatternRollback &&
1273 !isa<UnresolvedMaterializationRewrite>(Val: rewrite)) {
1274 // Unresolved materializations can always be rolled back (erased).
1275 llvm::report_fatal_error(reason: "pattern '" + patternName +
1276 "' rollback of IR modifications requested");
1277 }
1278 rewrite->rollback();
1279 }
1280 rewrites.resize(N: numRewritesToKeep);
1281}
1282
1283LogicalResult ConversionPatternRewriterImpl::remapValues(
1284 StringRef valueDiagTag, std::optional<Location> inputLoc,
1285 PatternRewriter &rewriter, ValueRange values,
1286 SmallVector<ValueVector> &remapped) {
1287 remapped.reserve(N: llvm::size(Range&: values));
1288
1289 for (const auto &it : llvm::enumerate(First&: values)) {
1290 Value operand = it.value();
1291 Type origType = operand.getType();
1292 Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1293
1294 if (!currentTypeConverter) {
1295 // The current pattern does not have a type converter. I.e., it does not
1296 // distinguish between legal and illegal types. For each operand, simply
1297 // pass through the most recently mapped values.
1298 remapped.push_back(Elt: mapping.lookupOrDefault(from: operand));
1299 continue;
1300 }
1301
1302 // If there is no legal conversion, fail to match this pattern.
1303 SmallVector<Type, 1> legalTypes;
1304 if (failed(Result: currentTypeConverter->convertType(t: origType, results&: legalTypes))) {
1305 notifyMatchFailure(loc: operandLoc, reasonCallback: [=](Diagnostic &diag) {
1306 diag << "unable to convert type for " << valueDiagTag << " #"
1307 << it.index() << ", type was " << origType;
1308 });
1309 return failure();
1310 }
1311 // If a type is converted to 0 types, there is nothing to do.
1312 if (legalTypes.empty()) {
1313 remapped.push_back(Elt: {});
1314 continue;
1315 }
1316
1317 ValueVector repl = mapping.lookupOrDefault(from: operand, desiredTypes: legalTypes);
1318 if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) {
1319 // Mapped values have the correct type or there is an existing
1320 // materialization. Or the operand is not mapped at all and has the
1321 // correct type.
1322 remapped.push_back(Elt: std::move(repl));
1323 continue;
1324 }
1325
1326 // Create a materialization for the most recently mapped values.
1327 repl = mapping.lookupOrDefault(from: operand);
1328 ValueRange castValues = buildUnresolvedMaterialization(
1329 kind: MaterializationKind::Target, ip: computeInsertPoint(vals: repl), loc: operandLoc,
1330 /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
1331 /*originalType=*/origType, converter: currentTypeConverter);
1332 remapped.push_back(Elt: castValues);
1333 }
1334 return success();
1335}
1336
1337bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
1338 // Check to see if this operation is ignored or was replaced.
1339 return replacedOps.count(key: op) || ignoredOps.count(key: op);
1340}
1341
1342bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
1343 // Check to see if this operation was replaced.
1344 return replacedOps.count(key: op);
1345}
1346
1347//===----------------------------------------------------------------------===//
1348// Type Conversion
1349//===----------------------------------------------------------------------===//
1350
1351FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
1352 ConversionPatternRewriter &rewriter, Region *region,
1353 const TypeConverter &converter,
1354 TypeConverter::SignatureConversion *entryConversion) {
1355 regionToConverter[region] = &converter;
1356 if (region->empty())
1357 return nullptr;
1358
1359 // Convert the arguments of each non-entry block within the region.
1360 for (Block &block :
1361 llvm::make_early_inc_range(Range: llvm::drop_begin(RangeOrContainer&: *region, N: 1))) {
1362 // Compute the signature for the block with the provided converter.
1363 std::optional<TypeConverter::SignatureConversion> conversion =
1364 converter.convertBlockSignature(block: &block);
1365 if (!conversion)
1366 return failure();
1367 // Convert the block with the computed signature.
1368 applySignatureConversion(rewriter, block: &block, converter: &converter, signatureConversion&: *conversion);
1369 }
1370
1371 // Convert the entry block. If an entry signature conversion was provided,
1372 // use that one. Otherwise, compute the signature with the type converter.
1373 if (entryConversion)
1374 return applySignatureConversion(rewriter, block: &region->front(), converter: &converter,
1375 signatureConversion&: *entryConversion);
1376 std::optional<TypeConverter::SignatureConversion> conversion =
1377 converter.convertBlockSignature(block: &region->front());
1378 if (!conversion)
1379 return failure();
1380 return applySignatureConversion(rewriter, block: &region->front(), converter: &converter,
1381 signatureConversion&: *conversion);
1382}
1383
1384Block *ConversionPatternRewriterImpl::applySignatureConversion(
1385 ConversionPatternRewriter &rewriter, Block *block,
1386 const TypeConverter *converter,
1387 TypeConverter::SignatureConversion &signatureConversion) {
1388#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1389 // A block cannot be converted multiple times.
1390 if (hasRewrite<BlockTypeConversionRewrite>(rewrites, block))
1391 llvm::report_fatal_error("block was already converted");
1392#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1393
1394 OpBuilder::InsertionGuard g(rewriter);
1395
1396 // If no arguments are being changed or added, there is nothing to do.
1397 unsigned origArgCount = block->getNumArguments();
1398 auto convertedTypes = signatureConversion.getConvertedTypes();
1399 if (llvm::equal(LRange: block->getArgumentTypes(), RRange&: convertedTypes))
1400 return block;
1401
1402 // Compute the locations of all block arguments in the new block.
1403 SmallVector<Location> newLocs(convertedTypes.size(),
1404 rewriter.getUnknownLoc());
1405 for (unsigned i = 0; i < origArgCount; ++i) {
1406 auto inputMap = signatureConversion.getInputMapping(input: i);
1407 if (!inputMap || inputMap->replacedWithValues())
1408 continue;
1409 Location origLoc = block->getArgument(i).getLoc();
1410 for (unsigned j = 0; j < inputMap->size; ++j)
1411 newLocs[inputMap->inputNo + j] = origLoc;
1412 }
1413
1414 // Insert a new block with the converted block argument types and move all ops
1415 // from the old block to the new block.
1416 Block *newBlock =
1417 rewriter.createBlock(parent: block->getParent(), insertPt: std::next(x: block->getIterator()),
1418 argTypes: convertedTypes, locs: newLocs);
1419
1420 // If a listener is attached to the dialect conversion, ops cannot be moved
1421 // to the destination block in bulk ("fast path"). This is because at the time
1422 // the notifications are sent, it is unknown which ops were moved. Instead,
1423 // ops should be moved one-by-one ("slow path"), so that a separate
1424 // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1425 // a bit more efficient, so we try to do that when possible.
1426 bool fastPath = !config.listener;
1427 if (fastPath) {
1428 appendRewrite<InlineBlockRewrite>(args&: newBlock, args&: block, args: newBlock->end());
1429 newBlock->getOperations().splice(where: newBlock->end(), L2&: block->getOperations());
1430 } else {
1431 while (!block->empty())
1432 rewriter.moveOpBefore(op: &block->front(), block: newBlock, iterator: newBlock->end());
1433 }
1434
1435 // Replace all uses of the old block with the new block.
1436 block->replaceAllUsesWith(newValue&: newBlock);
1437
1438 for (unsigned i = 0; i != origArgCount; ++i) {
1439 BlockArgument origArg = block->getArgument(i);
1440 Type origArgType = origArg.getType();
1441
1442 std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1443 signatureConversion.getInputMapping(input: i);
1444 if (!inputMap) {
1445 // This block argument was dropped and no replacement value was provided.
1446 // Materialize a replacement value "out of thin air".
1447 Value mat =
1448 buildUnresolvedMaterialization(
1449 kind: MaterializationKind::Source,
1450 ip: OpBuilder::InsertPoint(newBlock, newBlock->begin()),
1451 loc: origArg.getLoc(),
1452 /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1453 /*outputTypes=*/origArgType, /*originalType=*/Type(), converter)
1454 .front();
1455 replaceUsesOfBlockArgument(from: origArg, to: mat, converter);
1456 continue;
1457 }
1458
1459 if (inputMap->replacedWithValues()) {
1460 // This block argument was dropped and replacement values were provided.
1461 assert(inputMap->size == 0 &&
1462 "invalid to provide a replacement value when the argument isn't "
1463 "dropped");
1464 replaceUsesOfBlockArgument(from: origArg, to: inputMap->replacementValues,
1465 converter);
1466 continue;
1467 }
1468
1469 // This is a 1->1+ mapping.
1470 auto replArgs =
1471 newBlock->getArguments().slice(N: inputMap->inputNo, M: inputMap->size);
1472 replaceUsesOfBlockArgument(from: origArg, to: replArgs, converter);
1473 }
1474
1475 appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/args&: block, args&: newBlock);
1476
1477 // Erase the old block. (It is just unlinked for now and will be erased during
1478 // cleanup.)
1479 rewriter.eraseBlock(block);
1480
1481 return newBlock;
1482}
1483
1484//===----------------------------------------------------------------------===//
1485// Materializations
1486//===----------------------------------------------------------------------===//
1487
1488/// Build an unresolved materialization operation given an output type and set
1489/// of input operands.
1490ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
1491 MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1492 ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
1493 Type originalType, const TypeConverter *converter,
1494 UnrealizedConversionCastOp *castOp) {
1495 assert((!originalType || kind == MaterializationKind::Target) &&
1496 "original type is valid only for target materializations");
1497 assert(TypeRange(inputs) != outputTypes &&
1498 "materialization is not necessary");
1499
1500 // Create an unresolved materialization. We use a new OpBuilder to avoid
1501 // tracking the materialization like we do for other operations.
1502 OpBuilder builder(outputTypes.front().getContext());
1503 builder.setInsertionPoint(block: ip.getBlock(), insertPoint: ip.getPoint());
1504 auto convertOp =
1505 builder.create<UnrealizedConversionCastOp>(location: loc, args&: outputTypes, args&: inputs);
1506 if (!valuesToMap.empty())
1507 mapping.map(oldVal&: valuesToMap, newVal: convertOp.getResults());
1508 if (castOp)
1509 *castOp = convertOp;
1510 unresolvedMaterializations[convertOp] =
1511 UnresolvedMaterializationInfo(converter, kind, originalType);
1512 appendRewrite<UnresolvedMaterializationRewrite>(args&: convertOp,
1513 args: std::move(valuesToMap));
1514 return convertOp.getResults();
1515}
1516
1517Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
1518 Value value, const TypeConverter *converter) {
1519 // Try to find a replacement value with the same type in the conversion value
1520 // mapping. This includes cached materializations. We try to reuse those
1521 // instead of generating duplicate IR.
1522 ValueVector repl = mapping.lookupOrNull(from: value, desiredTypes: value.getType());
1523 if (!repl.empty())
1524 return repl.front();
1525
1526 // Check if the value is dead. No replacement value is needed in that case.
1527 // This is an approximate check that may have false negatives but does not
1528 // require computing and traversing an inverse mapping. (We may end up
1529 // building source materializations that are never used and that fold away.)
1530 if (llvm::all_of(Range: value.getUsers(),
1531 P: [&](Operation *op) { return replacedOps.contains(key: op); }) &&
1532 !mapping.isMappedTo(value))
1533 return Value();
1534
1535 // No replacement value was found. Get the latest replacement value
1536 // (regardless of the type) and build a source materialization to the
1537 // original type.
1538 repl = mapping.lookupOrNull(from: value);
1539 if (repl.empty()) {
1540 // No replacement value is registered in the mapping. This means that the
1541 // value is dropped and no longer needed. (If the value were still needed,
1542 // a source materialization producing a replacement value "out of thin air"
1543 // would have already been created during `replaceOp` or
1544 // `applySignatureConversion`.)
1545 return Value();
1546 }
1547
1548 // Note: `computeInsertPoint` computes the "earliest" insertion point at
1549 // which all values in `repl` are defined. It is important to emit the
1550 // materialization at that location because the same materialization may be
1551 // reused in a different context. (That's because materializations are cached
1552 // in the conversion value mapping.) The insertion point of the
1553 // materialization must be valid for all future users that may be created
1554 // later in the conversion process.
1555 Value castValue =
1556 buildUnresolvedMaterialization(kind: MaterializationKind::Source,
1557 ip: computeInsertPoint(vals: repl), loc: value.getLoc(),
1558 /*valuesToMap=*/repl, /*inputs=*/repl,
1559 /*outputTypes=*/value.getType(),
1560 /*originalType=*/Type(), converter)
1561 .front();
1562 return castValue;
1563}
1564
1565//===----------------------------------------------------------------------===//
1566// Rewriter Notification Hooks
1567//===----------------------------------------------------------------------===//
1568
1569void ConversionPatternRewriterImpl::notifyOperationInserted(
1570 Operation *op, OpBuilder::InsertPoint previous) {
1571 LLVM_DEBUG({
1572 logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
1573 << ")\n";
1574 });
1575 assert(!wasOpReplaced(op->getParentOp()) &&
1576 "attempting to insert into a block within a replaced/erased op");
1577
1578 if (!previous.isSet()) {
1579 // This is a newly created op.
1580 appendRewrite<CreateOperationRewrite>(args&: op);
1581 patternNewOps.insert(X: op);
1582 return;
1583 }
1584 Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
1585 ? nullptr
1586 : &*previous.getPoint();
1587 appendRewrite<MoveOperationRewrite>(args&: op, args: previous.getBlock(), args&: prevOp);
1588}
1589
1590void ConversionPatternRewriterImpl::replaceOp(
1591 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
1592 assert(newValues.size() == op->getNumResults());
1593 assert(!ignoredOps.contains(op) && "operation was already replaced");
1594
1595 // Check if replaced op is an unresolved materialization, i.e., an
1596 // unrealized_conversion_cast op that was created by the conversion driver.
1597 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(Val: op)) {
1598 // Make sure that the user does not mess with unresolved materializations
1599 // that were inserted by the conversion driver. We keep track of these
1600 // ops in internal data structures.
1601 assert(!unresolvedMaterializations.contains(castOp) &&
1602 "attempting to replace/erase an unresolved materialization");
1603 }
1604
1605 // Create mappings for each of the new result values.
1606 for (auto [repl, result] : llvm::zip_equal(t&: newValues, u: op->getResults())) {
1607 if (repl.empty()) {
1608 // This result was dropped and no replacement value was provided.
1609 // Materialize a replacement value "out of thin air".
1610 buildUnresolvedMaterialization(
1611 kind: MaterializationKind::Source, ip: computeInsertPoint(value: result),
1612 loc: result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
1613 /*outputTypes=*/result.getType(), /*originalType=*/Type(),
1614 converter: currentTypeConverter);
1615 continue;
1616 }
1617
1618 // Remap result to replacement value.
1619 if (repl.empty())
1620 continue;
1621 mapping.map(oldVal: static_cast<Value>(result), newVal: std::move(repl));
1622 }
1623
1624 appendRewrite<ReplaceOperationRewrite>(args&: op, args&: currentTypeConverter);
1625 // Mark this operation and all nested ops as replaced.
1626 op->walk(callback: [&](Operation *op) { replacedOps.insert(X: op); });
1627}
1628
1629void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
1630 BlockArgument from, ValueRange to, const TypeConverter *converter) {
1631 appendRewrite<ReplaceBlockArgRewrite>(args: from.getOwner(), args&: from, args&: converter);
1632 mapping.map(oldVal&: from, newVal&: to);
1633}
1634
1635void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
1636 assert(!wasOpReplaced(block->getParentOp()) &&
1637 "attempting to erase a block within a replaced/erased op");
1638 appendRewrite<EraseBlockRewrite>(args&: block);
1639
1640 // Unlink the block from its parent region. The block is kept in the rewrite
1641 // object and will be actually destroyed when rewrites are applied. This
1642 // allows us to keep the operations in the block live and undo the removal by
1643 // re-inserting the block.
1644 block->getParent()->getBlocks().remove(IT: block);
1645
1646 // Mark all nested ops as erased.
1647 block->walk(callback: [&](Operation *op) { replacedOps.insert(X: op); });
1648}
1649
1650void ConversionPatternRewriterImpl::notifyBlockInserted(
1651 Block *block, Region *previous, Region::iterator previousIt) {
1652 assert(!wasOpReplaced(block->getParentOp()) &&
1653 "attempting to insert into a region within a replaced/erased op");
1654 LLVM_DEBUG(
1655 {
1656 Operation *parent = block->getParentOp();
1657 if (parent) {
1658 logger.startLine() << "** Insert Block into : '" << parent->getName()
1659 << "'(" << parent << ")\n";
1660 } else {
1661 logger.startLine()
1662 << "** Insert Block into detached Region (nullptr parent op)'\n";
1663 }
1664 });
1665
1666 patternInsertedBlocks.insert(X: block);
1667
1668 if (!previous) {
1669 // This is a newly created block.
1670 appendRewrite<CreateBlockRewrite>(args&: block);
1671 return;
1672 }
1673 Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
1674 appendRewrite<MoveBlockRewrite>(args&: block, args&: previous, args&: prevBlock);
1675}
1676
1677void ConversionPatternRewriterImpl::inlineBlockBefore(Block *source,
1678 Block *dest,
1679 Block::iterator before) {
1680 appendRewrite<InlineBlockRewrite>(args&: dest, args&: source, args&: before);
1681}
1682
1683void ConversionPatternRewriterImpl::notifyMatchFailure(
1684 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1685 LLVM_DEBUG({
1686 Diagnostic diag(loc, DiagnosticSeverity::Remark);
1687 reasonCallback(diag);
1688 logger.startLine() << "** Failure : " << diag.str() << "\n";
1689 if (config.notifyCallback)
1690 config.notifyCallback(diag);
1691 });
1692}
1693
1694//===----------------------------------------------------------------------===//
1695// ConversionPatternRewriter
1696//===----------------------------------------------------------------------===//
1697
1698ConversionPatternRewriter::ConversionPatternRewriter(
1699 MLIRContext *ctx, const ConversionConfig &config)
1700 : PatternRewriter(ctx),
1701 impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
1702 setListener(impl.get());
1703}
1704
1705ConversionPatternRewriter::~ConversionPatternRewriter() = default;
1706
1707void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
1708 assert(op && newOp && "expected non-null op");
1709 replaceOp(op, newValues: newOp->getResults());
1710}
1711
1712void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
1713 assert(op->getNumResults() == newValues.size() &&
1714 "incorrect # of replacement values");
1715 LLVM_DEBUG({
1716 impl->logger.startLine()
1717 << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1718 });
1719 SmallVector<SmallVector<Value>> newVals =
1720 llvm::map_to_vector(C&: newValues, F: [](Value v) -> SmallVector<Value> {
1721 return v ? SmallVector<Value>{v} : SmallVector<Value>();
1722 });
1723 impl->replaceOp(op, newValues: std::move(newVals));
1724}
1725
1726void ConversionPatternRewriter::replaceOpWithMultiple(
1727 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
1728 assert(op->getNumResults() == newValues.size() &&
1729 "incorrect # of replacement values");
1730 LLVM_DEBUG({
1731 impl->logger.startLine()
1732 << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1733 });
1734 impl->replaceOp(op, newValues: std::move(newValues));
1735}
1736
1737void ConversionPatternRewriter::eraseOp(Operation *op) {
1738 LLVM_DEBUG({
1739 impl->logger.startLine()
1740 << "** Erase : '" << op->getName() << "'(" << op << ")\n";
1741 });
1742 SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {});
1743 impl->replaceOp(op, newValues: std::move(nullRepls));
1744}
1745
1746void ConversionPatternRewriter::eraseBlock(Block *block) {
1747 impl->eraseBlock(block);
1748}
1749
1750Block *ConversionPatternRewriter::applySignatureConversion(
1751 Block *block, TypeConverter::SignatureConversion &conversion,
1752 const TypeConverter *converter) {
1753 assert(!impl->wasOpReplaced(block->getParentOp()) &&
1754 "attempting to apply a signature conversion to a block within a "
1755 "replaced/erased op");
1756 return impl->applySignatureConversion(rewriter&: *this, block, converter, signatureConversion&: conversion);
1757}
1758
1759FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
1760 Region *region, const TypeConverter &converter,
1761 TypeConverter::SignatureConversion *entryConversion) {
1762 assert(!impl->wasOpReplaced(region->getParentOp()) &&
1763 "attempting to apply a signature conversion to a block within a "
1764 "replaced/erased op");
1765 return impl->convertRegionTypes(rewriter&: *this, region, converter, entryConversion);
1766}
1767
1768void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
1769 ValueRange to) {
1770 LLVM_DEBUG({
1771 impl->logger.startLine() << "** Replace Argument : '" << from << "'";
1772 if (Operation *parentOp = from.getOwner()->getParentOp()) {
1773 impl->logger.getOStream() << " (in region of '" << parentOp->getName()
1774 << "' (" << parentOp << ")\n";
1775 } else {
1776 impl->logger.getOStream() << " (unlinked block)\n";
1777 }
1778 });
1779 impl->replaceUsesOfBlockArgument(from, to, converter: impl->currentTypeConverter);
1780}
1781
1782Value ConversionPatternRewriter::getRemappedValue(Value key) {
1783 SmallVector<ValueVector> remappedValues;
1784 if (failed(Result: impl->remapValues(valueDiagTag: "value", /*inputLoc=*/std::nullopt, rewriter&: *this, values: key,
1785 remapped&: remappedValues)))
1786 return nullptr;
1787 assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
1788 return remappedValues.front().front();
1789}
1790
1791LogicalResult
1792ConversionPatternRewriter::getRemappedValues(ValueRange keys,
1793 SmallVectorImpl<Value> &results) {
1794 if (keys.empty())
1795 return success();
1796 SmallVector<ValueVector> remapped;
1797 if (failed(Result: impl->remapValues(valueDiagTag: "value", /*inputLoc=*/std::nullopt, rewriter&: *this, values: keys,
1798 remapped)))
1799 return failure();
1800 for (const auto &values : remapped) {
1801 assert(values.size() == 1 && "1:N conversion not supported");
1802 results.push_back(Elt: values.front());
1803 }
1804 return success();
1805}
1806
/// Move all operations of `source` into `dest` before `before`, replace
/// uses of the block arguments of `source` with `argValues`, and finally
/// erase `source`.
///
/// Every step goes through the rewriter or its impl, so each is recorded as
/// a rewrite: either one `InlineBlockRewrite` for the bulk splice, or one
/// move per op on the slow path.
void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
                                                  Block::iterator before,
                                                  ValueRange argValues) {
#ifndef NDEBUG
  assert(argValues.size() == source->getNumArguments() &&
         "incorrect # of argument replacement values");
  assert(!impl->wasOpReplaced(source->getParentOp()) &&
         "attempting to inline a block from a replaced/erased op");
  assert(!impl->wasOpReplaced(dest->getParentOp()) &&
         "attempting to inline a block into a replaced/erased op");
  auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
  // The source block will be deleted, so it should not have any users (i.e.,
  // there should be no predecessors).
  assert(llvm::all_of(source->getUsers(), opIgnored) &&
         "expected 'source' to have no predecessors");
#endif // NDEBUG

  // If a listener is attached to the dialect conversion, ops cannot be moved
  // to the destination block in bulk ("fast path"). This is because at the time
  // the notifications are sent, it is unknown which ops were moved. Instead,
  // ops should be moved one-by-one ("slow path"), so that a separate
  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
  // a bit more efficient, so we try to do that when possible.
  bool fastPath = !impl->config.listener;

  // Fast path only: record a single InlineBlockRewrite for the bulk splice
  // below. It is recorded before the argument replacements are recorded.
  if (fastPath)
    impl->inlineBlockBefore(source, dest, before);

  // Replace all uses of block arguments.
  for (auto it : llvm::zip(t: source->getArguments(), u&: argValues))
    replaceUsesOfBlockArgument(from: std::get<0>(t&: it), to: std::get<1>(t&: it));

  if (fastPath) {
    // Move all ops at once.
    dest->getOperations().splice(where: before, L2&: source->getOperations());
  } else {
    // Move op by op. Each move goes through the rewriter, which notifies the
    // impl (notifyOperationInserted) and records a MoveOperationRewrite.
    while (!source->empty())
      moveOpBefore(op: &source->front(), block: dest, iterator: before);
  }

  // Erase the source block, which is now empty.
  eraseBlock(block: source);
}
1851
1852void ConversionPatternRewriter::startOpModification(Operation *op) {
1853 assert(!impl->wasOpReplaced(op) &&
1854 "attempting to modify a replaced/erased op");
1855#ifndef NDEBUG
1856 impl->pendingRootUpdates.insert(op);
1857#endif
1858 impl->appendRewrite<ModifyOperationRewrite>(args&: op);
1859}
1860
1861void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
1862 assert(!impl->wasOpReplaced(op) &&
1863 "attempting to modify a replaced/erased op");
1864 PatternRewriter::finalizeOpModification(op);
1865 impl->patternModifiedOps.insert(X: op);
1866
1867 // There is nothing to do here, we only need to track the operation at the
1868 // start of the update.
1869#ifndef NDEBUG
1870 assert(impl->pendingRootUpdates.erase(op) &&
1871 "operation did not have a pending in-place update");
1872#endif
1873}
1874
1875void ConversionPatternRewriter::cancelOpModification(Operation *op) {
1876#ifndef NDEBUG
1877 assert(impl->pendingRootUpdates.erase(op) &&
1878 "operation did not have a pending in-place update");
1879#endif
1880 // Erase the last update for this operation.
1881 auto it = llvm::find_if(
1882 Range: llvm::reverse(C&: impl->rewrites), P: [&](std::unique_ptr<IRRewrite> &rewrite) {
1883 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(Val: rewrite.get());
1884 return modifyRewrite && modifyRewrite->getOperation() == op;
1885 });
1886 assert(it != impl->rewrites.rend() && "no root update started on op");
1887 (*it)->rollback();
1888 int updateIdx = std::prev(x: impl->rewrites.rend()) - it;
1889 impl->rewrites.erase(CI: impl->rewrites.begin() + updateIdx);
1890}
1891
1892detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
1893 return *impl;
1894}
1895
1896//===----------------------------------------------------------------------===//
1897// ConversionPattern
1898//===----------------------------------------------------------------------===//
1899
1900SmallVector<Value> ConversionPattern::getOneToOneAdaptorOperands(
1901 ArrayRef<ValueRange> operands) const {
1902 SmallVector<Value> oneToOneOperands;
1903 oneToOneOperands.reserve(N: operands.size());
1904 for (ValueRange operand : operands) {
1905 if (operand.size() != 1)
1906 llvm::report_fatal_error(reason: "pattern '" + getDebugName() +
1907 "' does not support 1:N conversion");
1908 oneToOneOperands.push_back(Elt: operand.front());
1909 }
1910 return oneToOneOperands;
1911}
1912
1913LogicalResult
1914ConversionPattern::matchAndRewrite(Operation *op,
1915 PatternRewriter &rewriter) const {
1916 auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1917 auto &rewriterImpl = dialectRewriter.getImpl();
1918
1919 // Track the current conversion pattern type converter in the rewriter.
1920 llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
1921 getTypeConverter());
1922
1923 // Remap the operands of the operation.
1924 SmallVector<ValueVector> remapped;
1925 if (failed(Result: rewriterImpl.remapValues(valueDiagTag: "operand", inputLoc: op->getLoc(), rewriter,
1926 values: op->getOperands(), remapped))) {
1927 return failure();
1928 }
1929 SmallVector<ValueRange> remappedAsRange =
1930 llvm::to_vector_of<ValueRange>(Range&: remapped);
1931 return matchAndRewrite(op, operands: remappedAsRange, rewriter&: dialectRewriter);
1932}
1933
1934//===----------------------------------------------------------------------===//
1935// OperationLegalizer
1936//===----------------------------------------------------------------------===//
1937
namespace {
/// A set of rewrite patterns that can be used to legalize a given operation.
using LegalizationPatterns = SmallVector<const Pattern *, 1>;

/// This class defines a recursive operation legalizer. An operation is
/// legalized by consulting the target's legality rules, then by trying to
/// fold it, and finally by applying conversion patterns. The IR produced by a
/// fold or a pattern is itself legalized recursively.
class OperationLegalizer {
public:
  using LegalizationAction = ConversionTarget::LegalizationAction;

  /// Build the legalizer: compute the legalization graph and apply the
  /// pattern cost model to `patterns`. `targetInfo` and `config` are stored
  /// by reference and therefore must outlive the legalizer.
  OperationLegalizer(const ConversionTarget &targetInfo,
                     const FrozenRewritePatternSet &patterns,
                     const ConversionConfig &config);

  /// Returns true if the given operation is known to be illegal on the target.
  bool isIllegal(Operation *op) const;

  /// Attempt to legalize the given operation. Returns success if the operation
  /// was legalized, failure otherwise.
  LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);

  /// Returns the conversion target in use by the legalizer.
  const ConversionTarget &getTarget() { return target; }

private:
  /// Attempt to legalize the given operation by folding it. Any constants
  /// materialized by the folder are legalized as well.
  LogicalResult legalizeWithFold(Operation *op,
                                 ConversionPatternRewriter &rewriter);

  /// Attempt to legalize the given operation by applying a pattern. Returns
  /// success if the operation was legalized, failure otherwise.
  LogicalResult legalizeWithPattern(Operation *op,
                                    ConversionPatternRewriter &rewriter);

  /// Return true if the given pattern may be applied to the given operation,
  /// false otherwise. A pattern without bounded rewrite recursion is rejected
  /// while it is already active on the current recursion stack (tracked in
  /// `appliedPatterns`).
  bool canApplyPattern(Operation *op, const Pattern &pattern,
                       ConversionPatternRewriter &rewriter);

  /// Legalize the resultant IR after successfully applying the given pattern.
  /// This covers blocks inserted or moved by the pattern, ops modified in
  /// place, and newly created ops.
  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
                                      ConversionPatternRewriter &rewriter,
                                      const SetVector<Operation *> &newOps,
                                      const SetVector<Operation *> &modifiedOps,
                                      const SetVector<Block *> &insertedBlocks);

  /// Legalizes the blocks inserted or moved by a pattern. A block's signature
  /// is converted directly when its region has a registered type converter;
  /// otherwise its (pre-existing) parent op is legalized again.
  LogicalResult
  legalizePatternBlockRewrites(Operation *op,
                               ConversionPatternRewriter &rewriter,
                               ConversionPatternRewriterImpl &impl,
                               const SetVector<Block *> &insertedBlocks,
                               const SetVector<Operation *> &newOps);
  /// Legalizes every op created by a pattern; fails on the first failure.
  LogicalResult
  legalizePatternCreatedOperations(ConversionPatternRewriter &rewriter,
                                   ConversionPatternRewriterImpl &impl,
                                   const SetVector<Operation *> &newOps);
  /// Legalizes every op a pattern modified in place; fails on the first
  /// failure.
  LogicalResult
  legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
                             ConversionPatternRewriterImpl &impl,
                             const SetVector<Operation *> &modifiedOps);

  //===--------------------------------------------------------------------===//
  // Cost Model
  //===--------------------------------------------------------------------===//

  /// Build an optimistic legalization graph given the provided patterns. This
  /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
  /// patterns for operations that are not directly legal, but may be
  /// transitively legal for the current target given the provided patterns.
  void buildLegalizationGraph(
      LegalizationPatterns &anyOpLegalizerPatterns,
      DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);

  /// Compute the benefit of each node within the computed legalization graph.
  /// This orders the patterns within 'legalizerPatterns' based upon two
  /// criteria:
  ///  1) Prefer patterns that have the lowest legalization depth, i.e.
  ///     represent the more direct mapping to the target.
  ///  2) When comparing patterns with the same legalization depth, prefer the
  ///     pattern with the highest PatternBenefit. This allows for users to
  ///     prefer specific legalizations over others.
  void computeLegalizationGraphBenefit(
      LegalizationPatterns &anyOpLegalizerPatterns,
      DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);

  /// Compute the legalization depth when legalizing an operation of the given
  /// type.
  unsigned computeOpLegalizationDepth(
      OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
      DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);

  /// Apply the conversion cost model to the given set of patterns, and return
  /// the smallest legalization depth of any of the patterns. See
  /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
  unsigned applyCostModelToPatterns(
      LegalizationPatterns &patterns,
      DenseMap<OperationName, unsigned> &minOpPatternDepth,
      DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);

  /// The patterns currently being applied on the legalization recursion
  /// stack. Entries are added in `canApplyPattern` and removed once the
  /// pattern's application finished (successfully or not).
  SmallPtrSet<const Pattern *, 8> appliedPatterns;

  /// The legalization information provided by the target.
  const ConversionTarget &target;

  /// The pattern applicator to use for conversions.
  PatternApplicator applicator;

  /// Dialect conversion configuration.
  const ConversionConfig &config;
};
} // namespace
2050
2051OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
2052 const FrozenRewritePatternSet &patterns,
2053 const ConversionConfig &config)
2054 : target(targetInfo), applicator(patterns), config(config) {
2055 // The set of patterns that can be applied to illegal operations to transform
2056 // them into legal ones.
2057 DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
2058 LegalizationPatterns anyOpLegalizerPatterns;
2059
2060 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2061 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2062}
2063
2064bool OperationLegalizer::isIllegal(Operation *op) const {
2065 return target.isIllegal(op);
2066}
2067
/// Attempt to legalize `op`. The strategies below are tried in order, and
/// the first one that succeeds wins:
///   1. ops reported by `isOpIgnored` are accepted without further work;
///   2. ops that the target marks legal are accepted. If they are marked
///      recursively legal, every nested op is added to `ignoredOps`, so the
///      nested ops are skipped later;
///   3. folding the op (`legalizeWithFold`);
///   4. applying a conversion pattern (`legalizeWithPattern`).
/// Returns failure if none of these succeeds. Everything else in this
/// function is debug logging.
LogicalResult
OperationLegalizer::legalize(Operation *op,
                             ConversionPatternRewriter &rewriter) {
#ifndef NDEBUG
  // Logging helpers; only referenced from the LLVM_DEBUG blocks below.
  const char *logLineComment =
      "//===-------------------------------------------===//\n";

  auto &logger = rewriter.getImpl().logger;
#endif

  // Check to see if the operation is ignored and doesn't need to be converted.
  bool isIgnored = rewriter.getImpl().isOpIgnored(op);

  LLVM_DEBUG({
    logger.getOStream() << "\n";
    logger.startLine() << logLineComment;
    logger.startLine() << "Legalizing operation : ";
    // Do not print the operation name if the operation is ignored. Ignored ops
    // may have been erased and should not be accessed. The pointer can be
    // printed safely.
    if (!isIgnored)
      logger.getOStream() << "'" << op->getName() << "' ";
    logger.getOStream() << "(" << op << ") {\n";
    logger.indent();

    // If the operation has no regions, just print it here.
    if (!isIgnored && op->getNumRegions() == 0) {
      op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
      logger.getOStream() << "\n\n";
    }
  });

  // Step 1: ignored ops are accepted as-is.
  if (isIgnored) {
    LLVM_DEBUG({
      logSuccess(logger, "operation marked 'ignored' during conversion");
      logger.startLine() << logLineComment;
    });
    return success();
  }

  // Step 2: check if this operation is legal on the target.
  if (auto legalityInfo = target.isLegal(op)) {
    LLVM_DEBUG({
      logSuccess(
          logger, "operation marked legal by the target{0}",
          legalityInfo->isRecursivelyLegal
              ? "; NOTE: operation is recursively legal; skipping internals"
              : "");
      logger.startLine() << logLineComment;
    });

    // If this operation is recursively legal, mark its children as ignored so
    // that we don't consider them for legalization.
    if (legalityInfo->isRecursivelyLegal) {
      op->walk(callback: [&](Operation *nested) {
        if (op != nested)
          rewriter.getImpl().ignoredOps.insert(X: nested);
      });
    }

    return success();
  }

  // Step 3: if the operation isn't legal, try to fold it in-place.
  // TODO: Should we always try to do this, even if the op is
  // already legal?
  if (succeeded(Result: legalizeWithFold(op, rewriter))) {
    LLVM_DEBUG({
      logSuccess(logger, "operation was folded");
      logger.startLine() << logLineComment;
    });
    return success();
  }

  // Step 4: otherwise, we need to apply a legalization pattern to this
  // operation.
  if (succeeded(Result: legalizeWithPattern(op, rewriter))) {
    LLVM_DEBUG({
      logSuccess(logger, "");
      logger.startLine() << logLineComment;
    });
    return success();
  }

  LLVM_DEBUG({
    logFailure(logger, "no matched legalization pattern");
    logger.startLine() << logLineComment;
  });
  return failure();
}
2157
/// Helper function that hands out the current contents of `obj` and leaves
/// `obj` value-initialized (a valid, empty state) so that it can be reused.
template <typename T>
static T moveAndReset(T &obj) {
  return std::exchange(obj, T());
}
2166
/// Attempt to legalize `op` by folding it with `tryFold`. There are three
/// outcomes:
///   - folding fails: failure is returned;
///   - the fold happened in place (no replacement values): `op` changed, so
///     legalization of `op` is attempted again from the beginning;
///   - the fold produced replacement values: every op materialized by the
///     folder is legalized first, and `op` is then replaced with those values.
/// If a materialized op cannot be legalized, all materialized ops are erased
/// and failure is returned. When pattern rollback is disabled in the config,
/// a fatal error is reported instead.
LogicalResult
OperationLegalizer::legalizeWithFold(Operation *op,
                                     ConversionPatternRewriter &rewriter) {
  auto &rewriterImpl = rewriter.getImpl();
  LLVM_DEBUG({
    rewriterImpl.logger.startLine() << "* Fold {\n";
    rewriterImpl.logger.indent();
  });
  // `rewriterImpl` is otherwise only used inside LLVM_DEBUG; silence the
  // unused-variable warning in builds where LLVM_DEBUG expands to nothing.
  (void)rewriterImpl;

  // Try to fold the operation. Materialized constants are inserted right
  // before `op`.
  StringRef opName = op->getName().getStringRef();
  SmallVector<Value, 2> replacementValues;
  SmallVector<Operation *, 2> newOps;
  rewriter.setInsertionPoint(op);
  if (failed(Result: rewriter.tryFold(op, results&: replacementValues, materializedConstants: &newOps))) {
    LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
    return failure();
  }

  // An empty list of replacement values indicates that the fold was in-place.
  // As the operation changed, a new legalization needs to be attempted.
  // NOTE(review): this path does not close the "* Fold {" debug log scope
  // opened above (no logSuccess/logFailure). In addition, no rewrite is
  // recorded here for the in-place change; confirm whether tryFold's
  // notifications make it undoable on rollback.
  if (replacementValues.empty())
    return legalize(op, rewriter);

  // Recursively legalize any new constant operations.
  for (Operation *newOp : newOps) {
    if (failed(Result: legalize(op: newOp, rewriter))) {
      LLVM_DEBUG(logFailure(rewriterImpl.logger,
                            "failed to legalize generated constant '{0}'",
                            newOp->getName()));
      if (!config.allowPatternRollback) {
        // Rolling back a folder is like rolling back a pattern.
        llvm::report_fatal_error(
            reason: "op '" + opName +
            "' folder rollback of IR modifications requested");
      }
      // Legalization failed: erase all materialized constants.
      // NOTE(review): constants that were already legalized (and possibly
      // replaced) by a pattern are erased again here. Verify that
      // ConversionPatternRewriter::eraseOp tolerates ops that were already
      // replaced.
      for (Operation *op : newOps)
        rewriter.eraseOp(op);
      return failure();
    }
  }

  // Insert a replacement for 'op' with the folded replacement values.
  rewriter.replaceOp(op, newValues: replacementValues);

  LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
  return success();
}
2217
/// Attempt to legalize `op` by applying one of the applicator's patterns.
/// The applicator is driven through three callbacks:
///   - `canApply` rejects patterns that would recurse (see `canApplyPattern`)
///     and notifies the listener that a pattern begins;
///   - `onFailure` clears the per-pattern tracking sets and resets the
///     rewriter to the state captured before any pattern ran;
///   - `onSuccess` legalizes the IR produced by the pattern. If that fails,
///     it resets the rewriter state, or reports a fatal error when pattern
///     rollback is disallowed.
LogicalResult
OperationLegalizer::legalizeWithPattern(Operation *op,
                                        ConversionPatternRewriter &rewriter) {
  auto &rewriterImpl = rewriter.getImpl();

  // Functor that returns if the given pattern may be applied.
  auto canApply = [&](const Pattern &pattern) {
    bool canApply = canApplyPattern(op, pattern, rewriter);
    if (canApply && config.listener)
      config.listener->notifyPatternBegin(pattern, op);
    return canApply;
  };

  // Snapshot of the rewriter state before any pattern is tried. Every failed
  // (or rolled-back) pattern is reset to this point, so each attempt starts
  // from the same IR.
  RewriterState curState = rewriterImpl.getCurrentState();
  // Functor that cleans up the rewriter state after a pattern failed to match.
  auto onFailure = [&](const Pattern &pattern) {
    assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
    // Drop the per-pattern tracking sets filled in by listener notifications.
    rewriterImpl.patternNewOps.clear();
    rewriterImpl.patternModifiedOps.clear();
    rewriterImpl.patternInsertedBlocks.clear();
    LLVM_DEBUG({
      logFailure(rewriterImpl.logger, "pattern failed to match");
      if (rewriterImpl.config.notifyCallback) {
        Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
        diag << "Failed to apply pattern \"" << pattern.getDebugName()
             << "\" on op:\n"
             << *op;
        rewriterImpl.config.notifyCallback(diag);
      }
    });
    if (config.listener)
      config.listener->notifyPatternEnd(pattern, status: failure());
    rewriterImpl.resetState(state: curState, patternName: pattern.getDebugName());
    // The pattern is no longer active on the recursion stack.
    appliedPatterns.erase(Ptr: &pattern);
  };

  // Functor that performs additional legalization when a pattern is
  // successfully applied.
  auto onSuccess = [&](const Pattern &pattern) {
    assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
    // Take ownership of the tracking sets and leave them empty. Recursive
    // legalization below may run other patterns that fill them anew.
    SetVector<Operation *> newOps = moveAndReset(obj&: rewriterImpl.patternNewOps);
    SetVector<Operation *> modifiedOps =
        moveAndReset(obj&: rewriterImpl.patternModifiedOps);
    SetVector<Block *> insertedBlocks =
        moveAndReset(obj&: rewriterImpl.patternInsertedBlocks);
    auto result = legalizePatternResult(op, pattern, rewriter, newOps,
                                        modifiedOps, insertedBlocks);
    appliedPatterns.erase(Ptr: &pattern);
    if (failed(Result: result)) {
      if (!rewriterImpl.config.allowPatternRollback)
        llvm::report_fatal_error(reason: "pattern '" + pattern.getDebugName() +
                                 "' produced IR that could not be legalized");
      rewriterImpl.resetState(state: curState, patternName: pattern.getDebugName());
    }
    if (config.listener)
      config.listener->notifyPatternEnd(pattern, status: result);
    return result;
  };

  // Try to match and rewrite a pattern on this operation.
  return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
                                    onSuccess);
}
2281
2282bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
2283 ConversionPatternRewriter &rewriter) {
2284 LLVM_DEBUG({
2285 auto &os = rewriter.getImpl().logger;
2286 os.getOStream() << "\n";
2287 os.startLine() << "* Pattern : '" << op->getName() << " -> (";
2288 llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
2289 os.getOStream() << ")' {\n";
2290 os.indent();
2291 });
2292
2293 // Ensure that we don't cycle by not allowing the same pattern to be
2294 // applied twice in the same recursion stack if it is not known to be safe.
2295 if (!pattern.hasBoundedRewriteRecursion() &&
2296 !appliedPatterns.insert(Ptr: &pattern).second) {
2297 LLVM_DEBUG(
2298 logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2299 return false;
2300 }
2301 return true;
2302}
2303
2304LogicalResult OperationLegalizer::legalizePatternResult(
2305 Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter,
2306 const SetVector<Operation *> &newOps,
2307 const SetVector<Operation *> &modifiedOps,
2308 const SetVector<Block *> &insertedBlocks) {
2309 auto &impl = rewriter.getImpl();
2310 assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2311
2312#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2313 // Check that the root was either replaced or updated in place.
2314 auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2315 auto replacedRoot = [&] {
2316 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2317 };
2318 auto updatedRootInPlace = [&] {
2319 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2320 };
2321 if (!replacedRoot() && !updatedRootInPlace())
2322 llvm::report_fatal_error("expected pattern to replace the root operation");
2323#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2324
2325 // Legalize each of the actions registered during application.
2326 if (failed(Result: legalizePatternBlockRewrites(op, rewriter, impl, insertedBlocks,
2327 newOps)) ||
2328 failed(Result: legalizePatternRootUpdates(rewriter, impl, modifiedOps)) ||
2329 failed(Result: legalizePatternCreatedOperations(rewriter, impl, newOps))) {
2330 return failure();
2331 }
2332
2333 LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2334 return success();
2335}
2336
2337LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2338 Operation *op, ConversionPatternRewriter &rewriter,
2339 ConversionPatternRewriterImpl &impl,
2340 const SetVector<Block *> &insertedBlocks,
2341 const SetVector<Operation *> &newOps) {
2342 SmallPtrSet<Operation *, 16> alreadyLegalized;
2343
2344 // If the pattern moved or created any blocks, make sure the types of block
2345 // arguments get legalized.
2346 for (Block *block : insertedBlocks) {
2347 // Only check blocks outside of the current operation.
2348 Operation *parentOp = block->getParentOp();
2349 if (!parentOp || parentOp == op || block->getNumArguments() == 0)
2350 continue;
2351
2352 // If the region of the block has a type converter, try to convert the block
2353 // directly.
2354 if (auto *converter = impl.regionToConverter.lookup(Val: block->getParent())) {
2355 std::optional<TypeConverter::SignatureConversion> conversion =
2356 converter->convertBlockSignature(block);
2357 if (!conversion) {
2358 LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
2359 "block"));
2360 return failure();
2361 }
2362 impl.applySignatureConversion(rewriter, block, converter, signatureConversion&: *conversion);
2363 continue;
2364 }
2365
2366 // Otherwise, try to legalize the parent operation if it was not generated
2367 // by this pattern. This is because we will attempt to legalize the parent
2368 // operation, and blocks in regions created by this pattern will already be
2369 // legalized later on.
2370 if (!newOps.count(key: parentOp) && alreadyLegalized.insert(Ptr: parentOp).second) {
2371 if (failed(Result: legalize(op: parentOp, rewriter))) {
2372 LLVM_DEBUG(logFailure(
2373 impl.logger, "operation '{0}'({1}) became illegal after rewrite",
2374 parentOp->getName(), parentOp));
2375 return failure();
2376 }
2377 }
2378 }
2379 return success();
2380}
2381
2382LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2383 ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
2384 const SetVector<Operation *> &newOps) {
2385 for (Operation *op : newOps) {
2386 if (failed(Result: legalize(op, rewriter))) {
2387 LLVM_DEBUG(logFailure(impl.logger,
2388 "failed to legalize generated operation '{0}'({1})",
2389 op->getName(), op));
2390 return failure();
2391 }
2392 }
2393 return success();
2394}
2395
2396LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2397 ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
2398 const SetVector<Operation *> &modifiedOps) {
2399 for (Operation *op : modifiedOps) {
2400 if (failed(Result: legalize(op, rewriter))) {
2401 LLVM_DEBUG(logFailure(
2402 impl.logger, "failed to legalize operation updated in-place '{0}'",
2403 op->getName()));
2404 return failure();
2405 }
2406 }
2407 return success();
2408}
2409
2410//===----------------------------------------------------------------------===//
2411// Cost Model
2412//===----------------------------------------------------------------------===//
2413
2414void OperationLegalizer::buildLegalizationGraph(
2415 LegalizationPatterns &anyOpLegalizerPatterns,
2416 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2417 // A mapping between an operation and a set of operations that can be used to
2418 // generate it.
2419 DenseMap<OperationName, SmallPtrSet<OperationName, 2>> parentOps;
2420 // A mapping between an operation and any currently invalid patterns it has.
2421 DenseMap<OperationName, SmallPtrSet<const Pattern *, 2>> invalidPatterns;
2422 // A worklist of patterns to consider for legality.
2423 SetVector<const Pattern *> patternWorklist;
2424
2425 // Build the mapping from operations to the parent ops that may generate them.
2426 applicator.walkAllPatterns(walk: [&](const Pattern &pattern) {
2427 std::optional<OperationName> root = pattern.getRootKind();
2428
2429 // If the pattern has no specific root, we can't analyze the relationship
2430 // between the root op and generated operations. Given that, add all such
2431 // patterns to the legalization set.
2432 if (!root) {
2433 anyOpLegalizerPatterns.push_back(Elt: &pattern);
2434 return;
2435 }
2436
2437 // Skip operations that are always known to be legal.
2438 if (target.getOpAction(op: *root) == LegalizationAction::Legal)
2439 return;
2440
2441 // Add this pattern to the invalid set for the root op and record this root
2442 // as a parent for any generated operations.
2443 invalidPatterns[*root].insert(Ptr: &pattern);
2444 for (auto op : pattern.getGeneratedOps())
2445 parentOps[op].insert(Ptr: *root);
2446
2447 // Add this pattern to the worklist.
2448 patternWorklist.insert(X: &pattern);
2449 });
2450
2451 // If there are any patterns that don't have a specific root kind, we can't
2452 // make direct assumptions about what operations will never be legalized.
2453 // Note: Technically we could, but it would require an analysis that may
2454 // recurse into itself. It would be better to perform this kind of filtering
2455 // at a higher level than here anyways.
2456 if (!anyOpLegalizerPatterns.empty()) {
2457 for (const Pattern *pattern : patternWorklist)
2458 legalizerPatterns[*pattern->getRootKind()].push_back(Elt: pattern);
2459 return;
2460 }
2461
2462 while (!patternWorklist.empty()) {
2463 auto *pattern = patternWorklist.pop_back_val();
2464
2465 // Check to see if any of the generated operations are invalid.
2466 if (llvm::any_of(Range: pattern->getGeneratedOps(), P: [&](OperationName op) {
2467 std::optional<LegalizationAction> action = target.getOpAction(op);
2468 return !legalizerPatterns.count(Val: op) &&
2469 (!action || action == LegalizationAction::Illegal);
2470 }))
2471 continue;
2472
2473 // Otherwise, if all of the generated operation are valid, this op is now
2474 // legal so add all of the child patterns to the worklist.
2475 legalizerPatterns[*pattern->getRootKind()].push_back(Elt: pattern);
2476 invalidPatterns[*pattern->getRootKind()].erase(Ptr: pattern);
2477
2478 // Add any invalid patterns of the parent operations to see if they have now
2479 // become legal.
2480 for (auto op : parentOps[*pattern->getRootKind()])
2481 patternWorklist.set_union(invalidPatterns[op]);
2482 }
2483}
2484
2485void OperationLegalizer::computeLegalizationGraphBenefit(
2486 LegalizationPatterns &anyOpLegalizerPatterns,
2487 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2488 // The smallest pattern depth, when legalizing an operation.
2489 DenseMap<OperationName, unsigned> minOpPatternDepth;
2490
2491 // For each operation that is transitively legal, compute a cost for it.
2492 for (auto &opIt : legalizerPatterns)
2493 if (!minOpPatternDepth.count(Val: opIt.first))
2494 computeOpLegalizationDepth(op: opIt.first, minOpPatternDepth,
2495 legalizerPatterns);
2496
2497 // Apply the cost model to the patterns that can match any operation. Those
2498 // with a specific operation type are already resolved when computing the op
2499 // legalization depth.
2500 if (!anyOpLegalizerPatterns.empty())
2501 applyCostModelToPatterns(patterns&: anyOpLegalizerPatterns, minOpPatternDepth,
2502 legalizerPatterns);
2503
2504 // Apply a cost model to the pattern applicator. We order patterns first by
2505 // depth then benefit. `legalizerPatterns` contains per-op patterns by
2506 // decreasing benefit.
2507 applicator.applyCostModel(model: [&](const Pattern &pattern) {
2508 ArrayRef<const Pattern *> orderedPatternList;
2509 if (std::optional<OperationName> rootName = pattern.getRootKind())
2510 orderedPatternList = legalizerPatterns[*rootName];
2511 else
2512 orderedPatternList = anyOpLegalizerPatterns;
2513
2514 // If the pattern is not found, then it was removed and cannot be matched.
2515 auto *it = llvm::find(Range&: orderedPatternList, Val: &pattern);
2516 if (it == orderedPatternList.end())
2517 return PatternBenefit::impossibleToMatch();
2518
2519 // Patterns found earlier in the list have higher benefit.
2520 return PatternBenefit(std::distance(first: it, last: orderedPatternList.end()));
2521 });
2522}
2523
2524unsigned OperationLegalizer::computeOpLegalizationDepth(
2525 OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2526 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2527 // Check for existing depth.
2528 auto depthIt = minOpPatternDepth.find(Val: op);
2529 if (depthIt != minOpPatternDepth.end())
2530 return depthIt->second;
2531
2532 // If a mapping for this operation does not exist, then this operation
2533 // is always legal. Return 0 as the depth for a directly legal operation.
2534 auto opPatternsIt = legalizerPatterns.find(Val: op);
2535 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2536 return 0u;
2537
2538 // Record this initial depth in case we encounter this op again when
2539 // recursively computing the depth.
2540 minOpPatternDepth.try_emplace(Key: op, Args: std::numeric_limits<unsigned>::max());
2541
2542 // Apply the cost model to the operation patterns, and update the minimum
2543 // depth.
2544 unsigned minDepth = applyCostModelToPatterns(
2545 patterns&: opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2546 minOpPatternDepth[op] = minDepth;
2547 return minDepth;
2548}
2549
2550unsigned OperationLegalizer::applyCostModelToPatterns(
2551 LegalizationPatterns &patterns,
2552 DenseMap<OperationName, unsigned> &minOpPatternDepth,
2553 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2554 unsigned minDepth = std::numeric_limits<unsigned>::max();
2555
2556 // Compute the depth for each pattern within the set.
2557 SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
2558 patternsByDepth.reserve(N: patterns.size());
2559 for (const Pattern *pattern : patterns) {
2560 unsigned depth = 1;
2561 for (auto generatedOp : pattern->getGeneratedOps()) {
2562 unsigned generatedOpDepth = computeOpLegalizationDepth(
2563 op: generatedOp, minOpPatternDepth, legalizerPatterns);
2564 depth = std::max(a: depth, b: generatedOpDepth + 1);
2565 }
2566 patternsByDepth.emplace_back(Args&: pattern, Args&: depth);
2567
2568 // Update the minimum depth of the pattern list.
2569 minDepth = std::min(a: minDepth, b: depth);
2570 }
2571
2572 // If the operation only has one legalization pattern, there is no need to
2573 // sort them.
2574 if (patternsByDepth.size() == 1)
2575 return minDepth;
2576
2577 // Sort the patterns by those likely to be the most beneficial.
2578 llvm::stable_sort(Range&: patternsByDepth,
2579 C: [](const std::pair<const Pattern *, unsigned> &lhs,
2580 const std::pair<const Pattern *, unsigned> &rhs) {
2581 // First sort by the smaller pattern legalization
2582 // depth.
2583 if (lhs.second != rhs.second)
2584 return lhs.second < rhs.second;
2585
2586 // Then sort by the larger pattern benefit.
2587 auto lhsBenefit = lhs.first->getBenefit();
2588 auto rhsBenefit = rhs.first->getBenefit();
2589 return lhsBenefit > rhsBenefit;
2590 });
2591
2592 // Update the legalization pattern to use the new sorted list.
2593 patterns.clear();
2594 for (auto &patternIt : patternsByDepth)
2595 patterns.push_back(Elt: patternIt.first);
2596 return minDepth;
2597}
2598
2599//===----------------------------------------------------------------------===//
2600// OperationConverter
2601//===----------------------------------------------------------------------===//
namespace {
/// Selects how an OperationConverter treats operations that fail to legalize.
enum OpConversionMode {
  /// Failed conversions are ignored, so illegal operations are allowed to
  /// co-exist with converted ones in the IR.
  Partial,

  /// The conversion only succeeds if every operation becomes legal for the
  /// given target.
  Full,

  /// Operations are only analyzed for legality; no actual rewrites are applied
  /// to the operations on success.
  Analysis,
};
} // namespace
2617
2618namespace mlir {
2619// This class converts operations to a given conversion target via a set of
2620// rewrite patterns. The conversion behaves differently depending on the
2621// conversion mode.
2622struct OperationConverter {
2623 explicit OperationConverter(const ConversionTarget &target,
2624 const FrozenRewritePatternSet &patterns,
2625 const ConversionConfig &config,
2626 OpConversionMode mode)
2627 : config(config), opLegalizer(target, patterns, this->config),
2628 mode(mode) {}
2629
2630 /// Converts the given operations to the conversion target.
2631 LogicalResult convertOperations(ArrayRef<Operation *> ops);
2632
2633private:
2634 /// Converts an operation with the given rewriter.
2635 LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
2636
2637 /// Dialect conversion configuration.
2638 ConversionConfig config;
2639
2640 /// The legalizer to use when converting operations.
2641 OperationLegalizer opLegalizer;
2642
2643 /// The conversion mode to use when legalizing operations.
2644 OpConversionMode mode;
2645};
2646} // namespace mlir
2647
2648LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2649 Operation *op) {
2650 // Legalize the given operation.
2651 if (failed(Result: opLegalizer.legalize(op, rewriter))) {
2652 // Handle the case of a failed conversion for each of the different modes.
2653 // Full conversions expect all operations to be converted.
2654 if (mode == OpConversionMode::Full)
2655 return op->emitError()
2656 << "failed to legalize operation '" << op->getName() << "'";
2657 // Partial conversions allow conversions to fail iff the operation was not
2658 // explicitly marked as illegal. If the user provided a `unlegalizedOps`
2659 // set, non-legalizable ops are added to that set.
2660 if (mode == OpConversionMode::Partial) {
2661 if (opLegalizer.isIllegal(op))
2662 return op->emitError()
2663 << "failed to legalize operation '" << op->getName()
2664 << "' that was explicitly marked illegal";
2665 if (config.unlegalizedOps)
2666 config.unlegalizedOps->insert(V: op);
2667 }
2668 } else if (mode == OpConversionMode::Analysis) {
2669 // Analysis conversions don't fail if any operations fail to legalize,
2670 // they are only interested in the operations that were successfully
2671 // legalized.
2672 if (config.legalizableOps)
2673 config.legalizableOps->insert(V: op);
2674 }
2675 return success();
2676}
2677
2678static LogicalResult
2679legalizeUnresolvedMaterialization(RewriterBase &rewriter,
2680 UnrealizedConversionCastOp op,
2681 const UnresolvedMaterializationInfo &info) {
2682 assert(!op.use_empty() &&
2683 "expected that dead materializations have already been DCE'd");
2684 Operation::operand_range inputOperands = op.getOperands();
2685
2686 // Try to materialize the conversion.
2687 if (const TypeConverter *converter = info.getConverter()) {
2688 rewriter.setInsertionPoint(op);
2689 SmallVector<Value> newMaterialization;
2690 switch (info.getMaterializationKind()) {
2691 case MaterializationKind::Target:
2692 newMaterialization = converter->materializeTargetConversion(
2693 builder&: rewriter, loc: op->getLoc(), resultType: op.getResultTypes(), inputs: inputOperands,
2694 originalType: info.getOriginalType());
2695 break;
2696 case MaterializationKind::Source:
2697 assert(op->getNumResults() == 1 && "expected single result");
2698 Value sourceMat = converter->materializeSourceConversion(
2699 builder&: rewriter, loc: op->getLoc(), resultType: op.getResultTypes().front(), inputs: inputOperands);
2700 if (sourceMat)
2701 newMaterialization.push_back(Elt: sourceMat);
2702 break;
2703 }
2704 if (!newMaterialization.empty()) {
2705#ifndef NDEBUG
2706 ValueRange newMaterializationRange(newMaterialization);
2707 assert(TypeRange(newMaterializationRange) == op.getResultTypes() &&
2708 "materialization callback produced value of incorrect type");
2709#endif // NDEBUG
2710 rewriter.replaceOp(op, newValues: newMaterialization);
2711 return success();
2712 }
2713 }
2714
2715 InFlightDiagnostic diag = op->emitError()
2716 << "failed to legalize unresolved materialization "
2717 "from ("
2718 << inputOperands.getTypes() << ") to ("
2719 << op.getResultTypes()
2720 << ") that remained live after conversion";
2721 diag.attachNote(noteLoc: op->getUsers().begin()->getLoc())
2722 << "see existing live user here: " << *op->getUsers().begin();
2723 return failure();
2724}
2725
2726LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2727 assert(!ops.empty() && "expected at least one operation");
2728 const ConversionTarget &target = opLegalizer.getTarget();
2729
2730 // Compute the set of operations and blocks to convert.
2731 SmallVector<Operation *> toConvert;
2732 for (auto *op : ops) {
2733 op->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
2734 callback: [&](Operation *op) {
2735 toConvert.push_back(Elt: op);
2736 // Don't check this operation's children for conversion if the
2737 // operation is recursively legal.
2738 auto legalityInfo = target.isLegal(op);
2739 if (legalityInfo && legalityInfo->isRecursivelyLegal)
2740 return WalkResult::skip();
2741 return WalkResult::advance();
2742 });
2743 }
2744
2745 // Convert each operation and discard rewrites on failure.
2746 ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
2747 ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2748
2749 for (auto *op : toConvert) {
2750 if (failed(Result: convert(rewriter, op))) {
2751 // Dialect conversion failed.
2752 if (rewriterImpl.config.allowPatternRollback) {
2753 // Rollback is allowed: restore the original IR.
2754 rewriterImpl.undoRewrites();
2755 } else {
2756 // Rollback is not allowed: apply all modifications that have been
2757 // performed so far.
2758 rewriterImpl.applyRewrites();
2759 }
2760 return failure();
2761 }
2762 }
2763
2764 // After a successful conversion, apply rewrites.
2765 rewriterImpl.applyRewrites();
2766
2767 // Gather all unresolved materializations.
2768 SmallVector<UnrealizedConversionCastOp> allCastOps;
2769 const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
2770 &materializations = rewriterImpl.unresolvedMaterializations;
2771 for (auto it : materializations)
2772 allCastOps.push_back(Elt: it.first);
2773
2774 // Reconcile all UnrealizedConversionCastOps that were inserted by the
2775 // dialect conversion frameworks. (Not the one that were inserted by
2776 // patterns.)
2777 SmallVector<UnrealizedConversionCastOp> remainingCastOps;
2778 reconcileUnrealizedCasts(castOps: allCastOps, remainingCastOps: &remainingCastOps);
2779
2780 // Try to legalize all unresolved materializations.
2781 if (config.buildMaterializations) {
2782 IRRewriter rewriter(rewriterImpl.context, config.listener);
2783 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
2784 auto it = materializations.find(Val: castOp);
2785 assert(it != materializations.end() && "inconsistent state");
2786 if (failed(
2787 Result: legalizeUnresolvedMaterialization(rewriter, op: castOp, info: it->second)))
2788 return failure();
2789 }
2790 }
2791
2792 return success();
2793}
2794
2795//===----------------------------------------------------------------------===//
2796// Reconcile Unrealized Casts
2797//===----------------------------------------------------------------------===//
2798
2799void mlir::reconcileUnrealizedCasts(
2800 ArrayRef<UnrealizedConversionCastOp> castOps,
2801 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
2802 SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);
2803 // This set is maintained only if `remainingCastOps` is provided.
2804 DenseSet<Operation *> erasedOps;
2805
2806 // Helper function that adds all operands to the worklist that are an
2807 // unrealized_conversion_cast op result.
2808 auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
2809 for (Value v : castOp.getInputs())
2810 if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
2811 worklist.insert(X: inputCastOp);
2812 };
2813
2814 // Helper function that return the unrealized_conversion_cast op that
2815 // defines all inputs of the given op (in the same order). Return "nullptr"
2816 // if there is no such op.
2817 auto getInputCast =
2818 [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
2819 if (castOp.getInputs().empty())
2820 return {};
2821 auto inputCastOp =
2822 castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
2823 if (!inputCastOp)
2824 return {};
2825 if (inputCastOp.getOutputs() != castOp.getInputs())
2826 return {};
2827 return inputCastOp;
2828 };
2829
2830 // Process ops in the worklist bottom-to-top.
2831 while (!worklist.empty()) {
2832 UnrealizedConversionCastOp castOp = worklist.pop_back_val();
2833 if (castOp->use_empty()) {
2834 // DCE: If the op has no users, erase it. Add the operands to the
2835 // worklist to find additional DCE opportunities.
2836 enqueueOperands(castOp);
2837 if (remainingCastOps)
2838 erasedOps.insert(V: castOp.getOperation());
2839 castOp->erase();
2840 continue;
2841 }
2842
2843 // Traverse the chain of input cast ops to see if an op with the same
2844 // input types can be found.
2845 UnrealizedConversionCastOp nextCast = castOp;
2846 while (nextCast) {
2847 if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
2848 // Found a cast where the input types match the output types of the
2849 // matched op. We can directly use those inputs and the matched op can
2850 // be removed.
2851 enqueueOperands(castOp);
2852 castOp.replaceAllUsesWith(values: nextCast.getInputs());
2853 if (remainingCastOps)
2854 erasedOps.insert(V: castOp.getOperation());
2855 castOp->erase();
2856 break;
2857 }
2858 nextCast = getInputCast(nextCast);
2859 }
2860 }
2861
2862 if (remainingCastOps)
2863 for (UnrealizedConversionCastOp op : castOps)
2864 if (!erasedOps.contains(V: op.getOperation()))
2865 remainingCastOps->push_back(Elt: op);
2866}
2867
2868//===----------------------------------------------------------------------===//
2869// Type Conversion
2870//===----------------------------------------------------------------------===//
2871
2872void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
2873 ArrayRef<Type> types) {
2874 assert(!types.empty() && "expected valid types");
2875 remapInput(origInputNo, /*newInputNo=*/argTypes.size(), newInputCount: types.size());
2876 addInputs(types);
2877}
2878
2879void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) {
2880 assert(!types.empty() &&
2881 "1->0 type remappings don't need to be added explicitly");
2882 argTypes.append(in_start: types.begin(), in_end: types.end());
2883}
2884
2885void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2886 unsigned newInputNo,
2887 unsigned newInputCount) {
2888 assert(!remappedInputs[origInputNo] && "input has already been remapped");
2889 assert(newInputCount != 0 && "expected valid input count");
2890 remappedInputs[origInputNo] =
2891 InputMapping{.inputNo: newInputNo, .size: newInputCount, /*replacementValues=*/{}};
2892}
2893
2894void TypeConverter::SignatureConversion::remapInput(
2895 unsigned origInputNo, ArrayRef<Value> replacements) {
2896 assert(!remappedInputs[origInputNo] && "input has already been remapped");
2897 remappedInputs[origInputNo] = InputMapping{
2898 .inputNo: origInputNo, /*size=*/0,
2899 .replacementValues: SmallVector<Value, 1>(replacements.begin(), replacements.end())};
2900}
2901
2902LogicalResult TypeConverter::convertType(Type t,
2903 SmallVectorImpl<Type> &results) const {
2904 assert(t && "expected non-null type");
2905
2906 {
2907 std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
2908 std::defer_lock);
2909 if (t.getContext()->isMultithreadingEnabled())
2910 cacheReadLock.lock();
2911 auto existingIt = cachedDirectConversions.find(Val: t);
2912 if (existingIt != cachedDirectConversions.end()) {
2913 if (existingIt->second)
2914 results.push_back(Elt: existingIt->second);
2915 return success(IsSuccess: existingIt->second != nullptr);
2916 }
2917 auto multiIt = cachedMultiConversions.find(Val: t);
2918 if (multiIt != cachedMultiConversions.end()) {
2919 results.append(in_start: multiIt->second.begin(), in_end: multiIt->second.end());
2920 return success();
2921 }
2922 }
2923 // Walk the added converters in reverse order to apply the most recently
2924 // registered first.
2925 size_t currentCount = results.size();
2926
2927 std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
2928 std::defer_lock);
2929
2930 for (const ConversionCallbackFn &converter : llvm::reverse(C: conversions)) {
2931 if (std::optional<LogicalResult> result = converter(t, results)) {
2932 if (t.getContext()->isMultithreadingEnabled())
2933 cacheWriteLock.lock();
2934 if (!succeeded(Result: *result)) {
2935 assert(results.size() == currentCount &&
2936 "failed type conversion should not change results");
2937 cachedDirectConversions.try_emplace(Key: t, Args: nullptr);
2938 return failure();
2939 }
2940 auto newTypes = ArrayRef<Type>(results).drop_front(N: currentCount);
2941 if (newTypes.size() == 1)
2942 cachedDirectConversions.try_emplace(Key: t, Args: newTypes.front());
2943 else
2944 cachedMultiConversions.try_emplace(Key: t, Args: llvm::to_vector<2>(Range&: newTypes));
2945 return success();
2946 } else {
2947 assert(results.size() == currentCount &&
2948 "failed type conversion should not change results");
2949 }
2950 }
2951 return failure();
2952}
2953
2954Type TypeConverter::convertType(Type t) const {
2955 // Use the multi-type result version to convert the type.
2956 SmallVector<Type, 1> results;
2957 if (failed(Result: convertType(t, results)))
2958 return nullptr;
2959
2960 // Check to ensure that only one type was produced.
2961 return results.size() == 1 ? results.front() : nullptr;
2962}
2963
2964LogicalResult
2965TypeConverter::convertTypes(TypeRange types,
2966 SmallVectorImpl<Type> &results) const {
2967 for (Type type : types)
2968 if (failed(Result: convertType(t: type, results)))
2969 return failure();
2970 return success();
2971}
2972
2973bool TypeConverter::isLegal(Type type) const {
2974 return convertType(t: type) == type;
2975}
2976bool TypeConverter::isLegal(Operation *op) const {
2977 return isLegal(range: op->getOperandTypes()) && isLegal(range: op->getResultTypes());
2978}
2979
2980bool TypeConverter::isLegal(Region *region) const {
2981 return llvm::all_of(Range&: *region, P: [this](Block &block) {
2982 return isLegal(range: block.getArgumentTypes());
2983 });
2984}
2985
2986bool TypeConverter::isSignatureLegal(FunctionType ty) const {
2987 return isLegal(range: llvm::concat<const Type>(Ranges: ty.getInputs(), Ranges: ty.getResults()));
2988}
2989
2990LogicalResult
2991TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
2992 SignatureConversion &result) const {
2993 // Try to convert the given input type.
2994 SmallVector<Type, 1> convertedTypes;
2995 if (failed(Result: convertType(t: type, results&: convertedTypes)))
2996 return failure();
2997
2998 // If this argument is being dropped, there is nothing left to do.
2999 if (convertedTypes.empty())
3000 return success();
3001
3002 // Otherwise, add the new inputs.
3003 result.addInputs(origInputNo: inputNo, types: convertedTypes);
3004 return success();
3005}
3006LogicalResult
3007TypeConverter::convertSignatureArgs(TypeRange types,
3008 SignatureConversion &result,
3009 unsigned origInputOffset) const {
3010 for (unsigned i = 0, e = types.size(); i != e; ++i)
3011 if (failed(Result: convertSignatureArg(inputNo: origInputOffset + i, type: types[i], result)))
3012 return failure();
3013 return success();
3014}
3015
3016Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
3017 Location loc, Type resultType,
3018 ValueRange inputs) const {
3019 for (const SourceMaterializationCallbackFn &fn :
3020 llvm::reverse(C: sourceMaterializations))
3021 if (Value result = fn(builder, resultType, inputs, loc))
3022 return result;
3023 return nullptr;
3024}
3025
3026Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
3027 Location loc, Type resultType,
3028 ValueRange inputs,
3029 Type originalType) const {
3030 SmallVector<Value> result = materializeTargetConversion(
3031 builder, loc, resultType: TypeRange(resultType), inputs, originalType);
3032 if (result.empty())
3033 return nullptr;
3034 assert(result.size() == 1 && "expected single result");
3035 return result.front();
3036}
3037
3038SmallVector<Value> TypeConverter::materializeTargetConversion(
3039 OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
3040 Type originalType) const {
3041 for (const TargetMaterializationCallbackFn &fn :
3042 llvm::reverse(C: targetMaterializations)) {
3043 SmallVector<Value> result =
3044 fn(builder, resultTypes, inputs, loc, originalType);
3045 if (result.empty())
3046 continue;
3047 assert(TypeRange(ValueRange(result)) == resultTypes &&
3048 "callback produced incorrect number of values or values with "
3049 "incorrect types");
3050 return result;
3051 }
3052 return {};
3053}
3054
3055std::optional<TypeConverter::SignatureConversion>
3056TypeConverter::convertBlockSignature(Block *block) const {
3057 SignatureConversion conversion(block->getNumArguments());
3058 if (failed(Result: convertSignatureArgs(types: block->getArgumentTypes(), result&: conversion)))
3059 return std::nullopt;
3060 return conversion;
3061}
3062
3063//===----------------------------------------------------------------------===//
3064// Type attribute conversion
3065//===----------------------------------------------------------------------===//
3066TypeConverter::AttributeConversionResult
3067TypeConverter::AttributeConversionResult::result(Attribute attr) {
3068 return AttributeConversionResult(attr, resultTag);
3069}
3070
3071TypeConverter::AttributeConversionResult
3072TypeConverter::AttributeConversionResult::na() {
3073 return AttributeConversionResult(nullptr, naTag);
3074}
3075
3076TypeConverter::AttributeConversionResult
3077TypeConverter::AttributeConversionResult::abort() {
3078 return AttributeConversionResult(nullptr, abortTag);
3079}
3080
3081bool TypeConverter::AttributeConversionResult::hasResult() const {
3082 return impl.getInt() == resultTag;
3083}
3084
3085bool TypeConverter::AttributeConversionResult::isNa() const {
3086 return impl.getInt() == naTag;
3087}
3088
3089bool TypeConverter::AttributeConversionResult::isAbort() const {
3090 return impl.getInt() == abortTag;
3091}
3092
3093Attribute TypeConverter::AttributeConversionResult::getResult() const {
3094 assert(hasResult() && "Cannot get result from N/A or abort");
3095 return impl.getPointer();
3096}
3097
3098std::optional<Attribute>
3099TypeConverter::convertTypeAttribute(Type type, Attribute attr) const {
3100 for (const TypeAttributeConversionCallbackFn &fn :
3101 llvm::reverse(C: typeAttributeConversions)) {
3102 AttributeConversionResult res = fn(type, attr);
3103 if (res.hasResult())
3104 return res.getResult();
3105 if (res.isAbort())
3106 return std::nullopt;
3107 }
3108 return std::nullopt;
3109}
3110
3111//===----------------------------------------------------------------------===//
3112// FunctionOpInterfaceSignatureConversion
3113//===----------------------------------------------------------------------===//
3114
3115static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3116 const TypeConverter &typeConverter,
3117 ConversionPatternRewriter &rewriter) {
3118 FunctionType type = dyn_cast<FunctionType>(Val: funcOp.getFunctionType());
3119 if (!type)
3120 return failure();
3121
3122 // Convert the original function types.
3123 TypeConverter::SignatureConversion result(type.getNumInputs());
3124 SmallVector<Type, 1> newResults;
3125 if (failed(Result: typeConverter.convertSignatureArgs(types: type.getInputs(), result)) ||
3126 failed(Result: typeConverter.convertTypes(types: type.getResults(), results&: newResults)) ||
3127 failed(Result: rewriter.convertRegionTypes(region: &funcOp.getFunctionBody(),
3128 converter: typeConverter, entryConversion: &result)))
3129 return failure();
3130
3131 // Update the function signature in-place.
3132 auto newType = FunctionType::get(context: rewriter.getContext(),
3133 inputs: result.getConvertedTypes(), results: newResults);
3134
3135 rewriter.modifyOpInPlace(root: funcOp, callable: [&] { funcOp.setType(newType); });
3136
3137 return success();
3138}
3139
3140/// Create a default conversion pattern that rewrites the type signature of a
3141/// FunctionOpInterface op. This only supports ops which use FunctionType to
3142/// represent their type.
3143namespace {
3144struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3145 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3146 MLIRContext *ctx,
3147 const TypeConverter &converter)
3148 : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
3149
3150 LogicalResult
3151 matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3152 ConversionPatternRewriter &rewriter) const override {
3153 FunctionOpInterface funcOp = cast<FunctionOpInterface>(Val: op);
3154 return convertFuncOpTypes(funcOp, typeConverter: *typeConverter, rewriter);
3155 }
3156};
3157
3158struct AnyFunctionOpInterfaceSignatureConversion
3159 : public OpInterfaceConversionPattern<FunctionOpInterface> {
3160 using OpInterfaceConversionPattern::OpInterfaceConversionPattern;
3161
3162 LogicalResult
3163 matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3164 ConversionPatternRewriter &rewriter) const override {
3165 return convertFuncOpTypes(funcOp, typeConverter: *typeConverter, rewriter);
3166 }
3167};
3168} // namespace
3169
3170FailureOr<Operation *>
3171mlir::convertOpResultTypes(Operation *op, ValueRange operands,
3172 const TypeConverter &converter,
3173 ConversionPatternRewriter &rewriter) {
3174 assert(op && "Invalid op");
3175 Location loc = op->getLoc();
3176 if (converter.isLegal(op))
3177 return rewriter.notifyMatchFailure(arg&: loc, msg: "op already legal");
3178
3179 OperationState newOp(loc, op->getName());
3180 newOp.addOperands(newOperands: operands);
3181
3182 SmallVector<Type> newResultTypes;
3183 if (failed(Result: converter.convertTypes(types: op->getResultTypes(), results&: newResultTypes)))
3184 return rewriter.notifyMatchFailure(arg&: loc, msg: "couldn't convert return types");
3185
3186 newOp.addTypes(newTypes: newResultTypes);
3187 newOp.addAttributes(newAttributes: op->getAttrs());
3188 return rewriter.create(state: newOp);
3189}
3190
3191void mlir::populateFunctionOpInterfaceTypeConversionPattern(
3192 StringRef functionLikeOpName, RewritePatternSet &patterns,
3193 const TypeConverter &converter) {
3194 patterns.add<FunctionOpInterfaceSignatureConversion>(
3195 arg&: functionLikeOpName, args: patterns.getContext(), args: converter);
3196}
3197
3198void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
3199 RewritePatternSet &patterns, const TypeConverter &converter) {
3200 patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3201 arg: converter, args: patterns.getContext());
3202}
3203
3204//===----------------------------------------------------------------------===//
3205// ConversionTarget
3206//===----------------------------------------------------------------------===//
3207
3208void ConversionTarget::setOpAction(OperationName op,
3209 LegalizationAction action) {
3210 legalOperations[op].action = action;
3211}
3212
3213void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
3214 LegalizationAction action) {
3215 for (StringRef dialect : dialectNames)
3216 legalDialects[dialect] = action;
3217}
3218
3219auto ConversionTarget::getOpAction(OperationName op) const
3220 -> std::optional<LegalizationAction> {
3221 std::optional<LegalizationInfo> info = getOpInfo(op);
3222 return info ? info->action : std::optional<LegalizationAction>();
3223}
3224
3225auto ConversionTarget::isLegal(Operation *op) const
3226 -> std::optional<LegalOpDetails> {
3227 std::optional<LegalizationInfo> info = getOpInfo(op: op->getName());
3228 if (!info)
3229 return std::nullopt;
3230
3231 // Returns true if this operation instance is known to be legal.
3232 auto isOpLegal = [&] {
3233 // Handle dynamic legality either with the provided legality function.
3234 if (info->action == LegalizationAction::Dynamic) {
3235 std::optional<bool> result = info->legalityFn(op);
3236 if (result)
3237 return *result;
3238 }
3239
3240 // Otherwise, the operation is only legal if it was marked 'Legal'.
3241 return info->action == LegalizationAction::Legal;
3242 };
3243 if (!isOpLegal())
3244 return std::nullopt;
3245
3246 // This operation is legal, compute any additional legality information.
3247 LegalOpDetails legalityDetails;
3248 if (info->isRecursivelyLegal) {
3249 auto legalityFnIt = opRecursiveLegalityFns.find(Val: op->getName());
3250 if (legalityFnIt != opRecursiveLegalityFns.end()) {
3251 legalityDetails.isRecursivelyLegal =
3252 legalityFnIt->second(op).value_or(u: true);
3253 } else {
3254 legalityDetails.isRecursivelyLegal = true;
3255 }
3256 }
3257 return legalityDetails;
3258}
3259
3260bool ConversionTarget::isIllegal(Operation *op) const {
3261 std::optional<LegalizationInfo> info = getOpInfo(op: op->getName());
3262 if (!info)
3263 return false;
3264
3265 if (info->action == LegalizationAction::Dynamic) {
3266 std::optional<bool> result = info->legalityFn(op);
3267 if (!result)
3268 return false;
3269
3270 return !(*result);
3271 }
3272
3273 return info->action == LegalizationAction::Illegal;
3274}
3275
3276static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(
3277 ConversionTarget::DynamicLegalityCallbackFn oldCallback,
3278 ConversionTarget::DynamicLegalityCallbackFn newCallback) {
3279 if (!oldCallback)
3280 return newCallback;
3281
3282 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3283 Operation *op) -> std::optional<bool> {
3284 if (std::optional<bool> result = newCl(op))
3285 return *result;
3286
3287 return oldCl(op);
3288 };
3289 return chain;
3290}
3291
3292void ConversionTarget::setLegalityCallback(
3293 OperationName name, const DynamicLegalityCallbackFn &callback) {
3294 assert(callback && "expected valid legality callback");
3295 auto *infoIt = legalOperations.find(Key: name);
3296 assert(infoIt != legalOperations.end() &&
3297 infoIt->second.action == LegalizationAction::Dynamic &&
3298 "expected operation to already be marked as dynamically legal");
3299 infoIt->second.legalityFn =
3300 composeLegalityCallbacks(oldCallback: std::move(infoIt->second.legalityFn), newCallback: callback);
3301}
3302
3303void ConversionTarget::markOpRecursivelyLegal(
3304 OperationName name, const DynamicLegalityCallbackFn &callback) {
3305 auto *infoIt = legalOperations.find(Key: name);
3306 assert(infoIt != legalOperations.end() &&
3307 infoIt->second.action != LegalizationAction::Illegal &&
3308 "expected operation to already be marked as legal");
3309 infoIt->second.isRecursivelyLegal = true;
3310 if (callback)
3311 opRecursiveLegalityFns[name] = composeLegalityCallbacks(
3312 oldCallback: std::move(opRecursiveLegalityFns[name]), newCallback: callback);
3313 else
3314 opRecursiveLegalityFns.erase(Val: name);
3315}
3316
3317void ConversionTarget::setLegalityCallback(
3318 ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
3319 assert(callback && "expected valid legality callback");
3320 for (StringRef dialect : dialects)
3321 dialectLegalityFns[dialect] = composeLegalityCallbacks(
3322 oldCallback: std::move(dialectLegalityFns[dialect]), newCallback: callback);
3323}
3324
3325void ConversionTarget::setLegalityCallback(
3326 const DynamicLegalityCallbackFn &callback) {
3327 assert(callback && "expected valid legality callback");
3328 unknownLegalityFn = composeLegalityCallbacks(oldCallback: unknownLegalityFn, newCallback: callback);
3329}
3330
3331auto ConversionTarget::getOpInfo(OperationName op) const
3332 -> std::optional<LegalizationInfo> {
3333 // Check for info for this specific operation.
3334 const auto *it = legalOperations.find(Key: op);
3335 if (it != legalOperations.end())
3336 return it->second;
3337 // Check for info for the parent dialect.
3338 auto dialectIt = legalDialects.find(Key: op.getDialectNamespace());
3339 if (dialectIt != legalDialects.end()) {
3340 DynamicLegalityCallbackFn callback;
3341 auto dialectFn = dialectLegalityFns.find(Key: op.getDialectNamespace());
3342 if (dialectFn != dialectLegalityFns.end())
3343 callback = dialectFn->second;
3344 return LegalizationInfo{.action: dialectIt->second, /*isRecursivelyLegal=*/false,
3345 .legalityFn: callback};
3346 }
3347 // Otherwise, check if we mark unknown operations as dynamic.
3348 if (unknownLegalityFn)
3349 return LegalizationInfo{.action: LegalizationAction::Dynamic,
3350 /*isRecursivelyLegal=*/false, .legalityFn: unknownLegalityFn};
3351 return std::nullopt;
3352}
3353
3354#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3355//===----------------------------------------------------------------------===//
3356// PDL Configuration
3357//===----------------------------------------------------------------------===//
3358
3359void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) {
3360 auto &rewriterImpl =
3361 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3362 rewriterImpl.currentTypeConverter = getTypeConverter();
3363}
3364
3365void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) {
3366 auto &rewriterImpl =
3367 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3368 rewriterImpl.currentTypeConverter = nullptr;
3369}
3370
3371/// Remap the given value using the rewriter and the type converter in the
3372/// provided config.
3373static FailureOr<SmallVector<Value>>
3374pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values) {
3375 SmallVector<Value> mappedValues;
3376 if (failed(Result: rewriter.getRemappedValues(keys: values, results&: mappedValues)))
3377 return failure();
3378 return std::move(mappedValues);
3379}
3380
3381void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
3382 patterns.getPDLPatterns().registerRewriteFunction(
3383 name: "convertValue",
3384 rewriteFn: [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
3385 auto results = pdllConvertValues(
3386 rewriter&: static_cast<ConversionPatternRewriter &>(rewriter), values: value);
3387 if (failed(Result: results))
3388 return failure();
3389 return results->front();
3390 });
3391 patterns.getPDLPatterns().registerRewriteFunction(
3392 name: "convertValues", rewriteFn: [](PatternRewriter &rewriter, ValueRange values) {
3393 return pdllConvertValues(
3394 rewriter&: static_cast<ConversionPatternRewriter &>(rewriter), values);
3395 });
3396 patterns.getPDLPatterns().registerRewriteFunction(
3397 name: "convertType",
3398 rewriteFn: [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
3399 auto &rewriterImpl =
3400 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3401 if (const TypeConverter *converter =
3402 rewriterImpl.currentTypeConverter) {
3403 if (Type newType = converter->convertType(t: type))
3404 return newType;
3405 return failure();
3406 }
3407 return type;
3408 });
3409 patterns.getPDLPatterns().registerRewriteFunction(
3410 name: "convertTypes",
3411 rewriteFn: [](PatternRewriter &rewriter,
3412 TypeRange types) -> FailureOr<SmallVector<Type>> {
3413 auto &rewriterImpl =
3414 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3415 const TypeConverter *converter = rewriterImpl.currentTypeConverter;
3416 if (!converter)
3417 return SmallVector<Type>(types);
3418
3419 SmallVector<Type> remappedTypes;
3420 if (failed(Result: converter->convertTypes(types, results&: remappedTypes)))
3421 return failure();
3422 return std::move(remappedTypes);
3423 });
3424}
3425#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
3426
3427//===----------------------------------------------------------------------===//
3428// Op Conversion Entry Points
3429//===----------------------------------------------------------------------===//
3430
3431/// This is the type of Action that is dispatched when a conversion is applied.
3432class ApplyConversionAction
3433 : public tracing::ActionImpl<ApplyConversionAction> {
3434public:
3435 using Base = tracing::ActionImpl<ApplyConversionAction>;
3436 ApplyConversionAction(ArrayRef<IRUnit> irUnits) : Base(irUnits) {}
3437 static constexpr StringLiteral tag = "apply-conversion";
3438 static constexpr StringLiteral desc =
3439 "Encapsulate the application of a dialect conversion";
3440
3441 void print(raw_ostream &os) const override { os << tag; }
3442};
3443
3444static LogicalResult applyConversion(ArrayRef<Operation *> ops,
3445 const ConversionTarget &target,
3446 const FrozenRewritePatternSet &patterns,
3447 ConversionConfig config,
3448 OpConversionMode mode) {
3449 if (ops.empty())
3450 return success();
3451 MLIRContext *ctx = ops.front()->getContext();
3452 LogicalResult status = success();
3453 SmallVector<IRUnit> irUnits(ops.begin(), ops.end());
3454 ctx->executeAction<ApplyConversionAction>(
3455 actionFn: [&] {
3456 OperationConverter opConverter(target, patterns, config, mode);
3457 status = opConverter.convertOperations(ops);
3458 },
3459 irUnits);
3460 return status;
3461}
3462
3463//===----------------------------------------------------------------------===//
3464// Partial Conversion
3465//===----------------------------------------------------------------------===//
3466
3467LogicalResult mlir::applyPartialConversion(
3468 ArrayRef<Operation *> ops, const ConversionTarget &target,
3469 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3470 return applyConversion(ops, target, patterns, config,
3471 mode: OpConversionMode::Partial);
3472}
3473LogicalResult
3474mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
3475 const FrozenRewritePatternSet &patterns,
3476 ConversionConfig config) {
3477 return applyPartialConversion(ops: llvm::ArrayRef(op), target, patterns, config);
3478}
3479
3480//===----------------------------------------------------------------------===//
3481// Full Conversion
3482//===----------------------------------------------------------------------===//
3483
3484LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
3485 const ConversionTarget &target,
3486 const FrozenRewritePatternSet &patterns,
3487 ConversionConfig config) {
3488 return applyConversion(ops, target, patterns, config, mode: OpConversionMode::Full);
3489}
3490LogicalResult mlir::applyFullConversion(Operation *op,
3491 const ConversionTarget &target,
3492 const FrozenRewritePatternSet &patterns,
3493 ConversionConfig config) {
3494 return applyFullConversion(ops: llvm::ArrayRef(op), target, patterns, config);
3495}
3496
3497//===----------------------------------------------------------------------===//
3498// Analysis Conversion
3499//===----------------------------------------------------------------------===//
3500
3501/// Find a common IsolatedFromAbove ancestor of the given ops. If at least one
3502/// op is a top-level module op (which is expected to be isolated from above),
3503/// return that op.
3504static Operation *findCommonAncestor(ArrayRef<Operation *> ops) {
3505 // Check if there is a top-level operation within `ops`. If so, return that
3506 // op.
3507 for (Operation *op : ops) {
3508 if (!op->getParentOp()) {
3509#ifndef NDEBUG
3510 assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
3511 "expected top-level op to be isolated from above");
3512 for (Operation *other : ops)
3513 assert(op->isAncestor(other) &&
3514 "expected ops to have a common ancestor");
3515#endif // NDEBUG
3516 return op;
3517 }
3518 }
3519
3520 // No top-level op. Find a common ancestor.
3521 Operation *commonAncestor =
3522 ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
3523 for (Operation *op : ops.drop_front()) {
3524 while (!commonAncestor->isProperAncestor(other: op)) {
3525 commonAncestor =
3526 commonAncestor->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
3527 assert(commonAncestor &&
3528 "expected to find a common isolated from above ancestor");
3529 }
3530 }
3531
3532 return commonAncestor;
3533}
3534
3535LogicalResult mlir::applyAnalysisConversion(
3536 ArrayRef<Operation *> ops, ConversionTarget &target,
3537 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3538#ifndef NDEBUG
3539 if (config.legalizableOps)
3540 assert(config.legalizableOps->empty() && "expected empty set");
3541#endif // NDEBUG
3542
3543 // Clone closted common ancestor that is isolated from above.
3544 Operation *commonAncestor = findCommonAncestor(ops);
3545 IRMapping mapping;
3546 Operation *clonedAncestor = commonAncestor->clone(mapper&: mapping);
3547 // Compute inverse IR mapping.
3548 DenseMap<Operation *, Operation *> inverseOperationMap;
3549 for (auto &it : mapping.getOperationMap())
3550 inverseOperationMap[it.second] = it.first;
3551
3552 // Convert the cloned operations. The original IR will remain unchanged.
3553 SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
3554 C&: ops, F: [&](Operation *op) { return mapping.lookup(from: op); });
3555 LogicalResult status = applyConversion(ops: opsToConvert, target, patterns, config,
3556 mode: OpConversionMode::Analysis);
3557
3558 // Remap `legalizableOps`, so that they point to the original ops and not the
3559 // cloned ops.
3560 if (config.legalizableOps) {
3561 DenseSet<Operation *> originalLegalizableOps;
3562 for (Operation *op : *config.legalizableOps)
3563 originalLegalizableOps.insert(V: inverseOperationMap[op]);
3564 *config.legalizableOps = std::move(originalLegalizableOps);
3565 }
3566
3567 // Erase the cloned IR.
3568 clonedAncestor->erase();
3569 return status;
3570}
3571
3572LogicalResult
3573mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
3574 const FrozenRewritePatternSet &patterns,
3575 ConversionConfig config) {
3576 return applyAnalysisConversion(ops: llvm::ArrayRef(op), target, patterns, config);
3577}
3578

source code of mlir/lib/Transforms/Utils/DialectConversion.cpp