1//===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Transforms/DialectConversion.h"
10#include "mlir/Config/mlir-config.h"
11#include "mlir/IR/Block.h"
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/BuiltinOps.h"
14#include "mlir/IR/Dominance.h"
15#include "mlir/IR/IRMapping.h"
16#include "mlir/IR/Iterators.h"
17#include "mlir/Interfaces/FunctionInterfaces.h"
18#include "mlir/Rewrite/PatternApplicator.h"
19#include "llvm/ADT/ScopeExit.h"
20#include "llvm/ADT/SetVector.h"
21#include "llvm/ADT/SmallPtrSet.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/FormatVariadic.h"
24#include "llvm/Support/SaveAndRestore.h"
25#include "llvm/Support/ScopedPrinter.h"
26#include <optional>
27
28using namespace mlir;
29using namespace mlir::detail;
30
31#define DEBUG_TYPE "dialect-conversion"
32
33/// A utility function to log a successful result for the given reason.
34template <typename... Args>
35static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
36 LLVM_DEBUG({
37 os.unindent();
38 os.startLine() << "} -> SUCCESS";
39 if (!fmt.empty())
40 os.getOStream() << " : "
41 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
42 os.getOStream() << "\n";
43 });
44}
45
46/// A utility function to log a failure result for the given reason.
47template <typename... Args>
48static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
49 LLVM_DEBUG({
50 os.unindent();
51 os.startLine() << "} -> FAILURE : "
52 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
53 << "\n";
54 });
55}
56
57/// Helper function that computes an insertion point where the given value is
58/// defined and can be used without a dominance violation.
59static OpBuilder::InsertPoint computeInsertPoint(Value value) {
60 Block *insertBlock = value.getParentBlock();
61 Block::iterator insertPt = insertBlock->begin();
62 if (OpResult inputRes = dyn_cast<OpResult>(Val&: value))
63 insertPt = ++inputRes.getOwner()->getIterator();
64 return OpBuilder::InsertPoint(insertBlock, insertPt);
65}
66
67/// Helper function that computes an insertion point where the given values are
68/// defined and can be used without a dominance violation.
69static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) {
70 assert(!vals.empty() && "expected at least one value");
71 DominanceInfo domInfo;
72 OpBuilder::InsertPoint pt = computeInsertPoint(value: vals.front());
73 for (Value v : vals.drop_front()) {
74 // Choose the "later" insertion point.
75 OpBuilder::InsertPoint nextPt = computeInsertPoint(value: v);
76 if (domInfo.dominates(aBlock: pt.getBlock(), aIt: pt.getPoint(), bBlock: nextPt.getBlock(),
77 bIt: nextPt.getPoint())) {
78 // pt is before nextPt => choose nextPt.
79 pt = nextPt;
80 } else {
81#ifndef NDEBUG
82 // nextPt should be before pt => choose pt.
83 // If pt, nextPt are no dominance relationship, then there is no valid
84 // insertion point at which all given values are defined.
85 bool dom = domInfo.dominates(aBlock: nextPt.getBlock(), aIt: nextPt.getPoint(),
86 bBlock: pt.getBlock(), bIt: pt.getPoint());
87 assert(dom && "unable to find valid insertion point");
88#endif // NDEBUG
89 }
90 }
91 return pt;
92}
93
94//===----------------------------------------------------------------------===//
95// ConversionValueMapping
96//===----------------------------------------------------------------------===//
97
98/// A vector of SSA values, optimized for the most common case of a single
99/// value.
100using ValueVector = SmallVector<Value, 1>;
101
102namespace {
103
104/// Helper class to make it possible to use `ValueVector` as a key in DenseMap.
105struct ValueVectorMapInfo {
106 static ValueVector getEmptyKey() { return ValueVector{Value()}; }
107 static ValueVector getTombstoneKey() { return ValueVector{Value(), Value()}; }
108 static ::llvm::hash_code getHashValue(const ValueVector &val) {
109 return ::llvm::hash_combine_range(R: val);
110 }
111 static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) {
112 return LHS == RHS;
113 }
114};
115
116/// This class wraps a IRMapping to provide recursive lookup
117/// functionality, i.e. we will traverse if the mapped value also has a mapping.
118struct ConversionValueMapping {
119 /// Return "true" if an SSA value is mapped to the given value. May return
120 /// false positives.
121 bool isMappedTo(Value value) const { return mappedTo.contains(V: value); }
122
123 /// Lookup the most recently mapped values with the desired types in the
124 /// mapping.
125 ///
126 /// Special cases:
127 /// - If the desired type range is empty, simply return the most recently
128 /// mapped values.
129 /// - If there is no mapping to the desired types, also return the most
130 /// recently mapped values.
131 /// - If there is no mapping for the given values at all, return the given
132 /// value.
133 ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
134
135 /// Lookup the given value within the map, or return an empty vector if the
136 /// value is not mapped. If it is mapped, this follows the same behavior
137 /// as `lookupOrDefault`.
138 ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;
139
140 template <typename T>
141 struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
142
143 /// Map a value vector to the one provided.
144 template <typename OldVal, typename NewVal>
145 std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
146 map(OldVal &&oldVal, NewVal &&newVal) {
147 LLVM_DEBUG({
148 ValueVector next(newVal);
149 while (true) {
150 assert(next != oldVal && "inserting cyclic mapping");
151 auto it = mapping.find(next);
152 if (it == mapping.end())
153 break;
154 next = it->second;
155 }
156 });
157 mappedTo.insert_range(newVal);
158
159 mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
160 }
161
162 /// Map a value vector or single value to the one provided.
163 template <typename OldVal, typename NewVal>
164 std::enable_if_t<!IsValueVector<OldVal>::value ||
165 !IsValueVector<NewVal>::value>
166 map(OldVal &&oldVal, NewVal &&newVal) {
167 if constexpr (IsValueVector<OldVal>{}) {
168 map(std::forward<OldVal>(oldVal), ValueVector{newVal});
169 } else if constexpr (IsValueVector<NewVal>{}) {
170 map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
171 } else {
172 map(oldVal: ValueVector{oldVal}, newVal: ValueVector{newVal});
173 }
174 }
175
176 void map(Value oldVal, SmallVector<Value> &&newVal) {
177 map(oldVal: ValueVector{oldVal}, newVal: ValueVector(std::move(newVal)));
178 }
179
180 /// Drop the last mapping for the given values.
181 void erase(const ValueVector &value) { mapping.erase(Val: value); }
182
183private:
184 /// Current value mappings.
185 DenseMap<ValueVector, ValueVector, ValueVectorMapInfo> mapping;
186
187 /// All SSA values that are mapped to. May contain false positives.
188 DenseSet<Value> mappedTo;
189};
190} // namespace
191
192ValueVector
193ConversionValueMapping::lookupOrDefault(Value from,
194 TypeRange desiredTypes) const {
195 // Try to find the deepest values that have the desired types. If there is no
196 // such mapping, simply return the deepest values.
197 ValueVector desiredValue;
198 ValueVector current{from};
199 do {
200 // Store the current value if the types match.
201 if (TypeRange(ValueRange(current)) == desiredTypes)
202 desiredValue = current;
203
204 // If possible, Replace each value with (one or multiple) mapped values.
205 ValueVector next;
206 for (Value v : current) {
207 auto it = mapping.find(Val: {v});
208 if (it != mapping.end()) {
209 llvm::append_range(C&: next, R: it->second);
210 } else {
211 next.push_back(Elt: v);
212 }
213 }
214 if (next != current) {
215 // If at least one value was replaced, continue the lookup from there.
216 current = std::move(next);
217 continue;
218 }
219
220 // Otherwise: Check if there is a mapping for the entire vector. Such
221 // mappings are materializations. (N:M mapping are not supported for value
222 // replacements.)
223 //
224 // Note: From a correctness point of view, materializations do not have to
225 // be stored (and looked up) in the mapping. But for performance reasons,
226 // we choose to reuse existing IR (when possible) instead of creating it
227 // multiple times.
228 auto it = mapping.find(Val: current);
229 if (it == mapping.end()) {
230 // No mapping found: The lookup stops here.
231 break;
232 }
233 current = it->second;
234 } while (true);
235
236 // If the desired values were found use them, otherwise default to the leaf
237 // values.
238 // Note: If `desiredTypes` is empty, this function always returns `current`.
239 return !desiredValue.empty() ? std::move(desiredValue) : std::move(current);
240}
241
242ValueVector ConversionValueMapping::lookupOrNull(Value from,
243 TypeRange desiredTypes) const {
244 ValueVector result = lookupOrDefault(from, desiredTypes);
245 if (result == ValueVector{from} ||
246 (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes))
247 return {};
248 return result;
249}
250
251//===----------------------------------------------------------------------===//
252// Rewriter and Translation State
253//===----------------------------------------------------------------------===//
254namespace {
/// Snapshot of the conversion rewriter's progress, expressed as three
/// counters. A snapshot can be saved and later used to undo everything that
/// happened after it was taken.
struct RewriterState {
  RewriterState(unsigned rewriteCount, unsigned ignoredOpCount,
                unsigned replacedOpCount)
      : numRewrites(rewriteCount), numIgnoredOperations(ignoredOpCount),
        numReplacedOps(replacedOpCount) {}

  /// Number of rewrites performed at the time of the snapshot.
  unsigned numRewrites;

  /// Number of ignored operations at the time of the snapshot.
  unsigned numIgnoredOperations;

  /// Number of replaced ops scheduled for erasure at the time of the
  /// snapshot.
  unsigned numReplacedOps;
};
272
273//===----------------------------------------------------------------------===//
274// IR rewrites
275//===----------------------------------------------------------------------===//
276
277/// An IR rewrite that can be committed (upon success) or rolled back (upon
278/// failure).
279///
280/// The dialect conversion keeps track of IR modifications (requested by the
281/// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
282/// are directly applied to the IR as the rewriter API is used, some are applied
283/// partially, and some are delayed until the `IRRewrite` objects are committed.
284class IRRewrite {
285public:
286 /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
287 /// Enum values are ordered, so that they can be used in `classof`: first all
288 /// block rewrites, then all operation rewrites.
289 enum class Kind {
290 // Block rewrites
291 CreateBlock,
292 EraseBlock,
293 InlineBlock,
294 MoveBlock,
295 BlockTypeConversion,
296 ReplaceBlockArg,
297 // Operation rewrites
298 MoveOperation,
299 ModifyOperation,
300 ReplaceOperation,
301 CreateOperation,
302 UnresolvedMaterialization
303 };
304
305 virtual ~IRRewrite() = default;
306
307 /// Roll back the rewrite. Operations may be erased during rollback.
308 virtual void rollback() = 0;
309
310 /// Commit the rewrite. At this point, it is certain that the dialect
311 /// conversion will succeed. All IR modifications, except for operation/block
312 /// erasure, must be performed through the given rewriter.
313 ///
314 /// Instead of erasing operations/blocks, they should merely be unlinked
315 /// commit phase and finally be erased during the cleanup phase. This is
316 /// because internal dialect conversion state (such as `mapping`) may still
317 /// be using them.
318 ///
319 /// Any IR modification that was already performed before the commit phase
320 /// (e.g., insertion of an op) must be communicated to the listener that may
321 /// be attached to the given rewriter.
322 virtual void commit(RewriterBase &rewriter) {}
323
324 /// Cleanup operations/blocks. Cleanup is called after commit.
325 virtual void cleanup(RewriterBase &rewriter) {}
326
327 Kind getKind() const { return kind; }
328
329 static bool classof(const IRRewrite *rewrite) { return true; }
330
331protected:
332 IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
333 : kind(kind), rewriterImpl(rewriterImpl) {}
334
335 const ConversionConfig &getConfig() const;
336
337 const Kind kind;
338 ConversionPatternRewriterImpl &rewriterImpl;
339};
340
341/// A block rewrite.
342class BlockRewrite : public IRRewrite {
343public:
344 /// Return the block that this rewrite operates on.
345 Block *getBlock() const { return block; }
346
347 static bool classof(const IRRewrite *rewrite) {
348 return rewrite->getKind() >= Kind::CreateBlock &&
349 rewrite->getKind() <= Kind::ReplaceBlockArg;
350 }
351
352protected:
353 BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
354 Block *block)
355 : IRRewrite(kind, rewriterImpl), block(block) {}
356
357 // The block that this rewrite operates on.
358 Block *block;
359};
360
361/// Creation of a block. Block creations are immediately reflected in the IR.
362/// There is no extra work to commit the rewrite. During rollback, the newly
363/// created block is erased.
364class CreateBlockRewrite : public BlockRewrite {
365public:
366 CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
367 : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
368
369 static bool classof(const IRRewrite *rewrite) {
370 return rewrite->getKind() == Kind::CreateBlock;
371 }
372
373 void commit(RewriterBase &rewriter) override {
374 // The block was already created and inserted. Just inform the listener.
375 if (auto *listener = rewriter.getListener())
376 listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
377 }
378
379 void rollback() override {
380 // Unlink all of the operations within this block, they will be deleted
381 // separately.
382 auto &blockOps = block->getOperations();
383 while (!blockOps.empty())
384 blockOps.remove(IT: blockOps.begin());
385 block->dropAllUses();
386 if (block->getParent())
387 block->erase();
388 else
389 delete block;
390 }
391};
392
393/// Erasure of a block. Block erasures are partially reflected in the IR. Erased
394/// blocks are immediately unlinked, but only erased during cleanup. This makes
395/// it easier to rollback a block erasure: the block is simply inserted into its
396/// original location.
397class EraseBlockRewrite : public BlockRewrite {
398public:
399 EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
400 : BlockRewrite(Kind::EraseBlock, rewriterImpl, block),
401 region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
402
403 static bool classof(const IRRewrite *rewrite) {
404 return rewrite->getKind() == Kind::EraseBlock;
405 }
406
407 ~EraseBlockRewrite() override {
408 assert(!block &&
409 "rewrite was neither rolled back nor committed/cleaned up");
410 }
411
412 void rollback() override {
413 // The block (owned by this rewrite) was not actually erased yet. It was
414 // just unlinked. Put it back into its original position.
415 assert(block && "expected block");
416 auto &blockList = region->getBlocks();
417 Region::iterator before = insertBeforeBlock
418 ? Region::iterator(insertBeforeBlock)
419 : blockList.end();
420 blockList.insert(where: before, New: block);
421 block = nullptr;
422 }
423
424 void commit(RewriterBase &rewriter) override {
425 // Erase the block.
426 assert(block && "expected block");
427 assert(block->empty() && "expected empty block");
428
429 // Notify the listener that the block is about to be erased.
430 if (auto *listener =
431 dyn_cast_or_null<RewriterBase::Listener>(Val: rewriter.getListener()))
432 listener->notifyBlockErased(block);
433 }
434
435 void cleanup(RewriterBase &rewriter) override {
436 // Erase the block.
437 block->dropAllDefinedValueUses();
438 delete block;
439 block = nullptr;
440 }
441
442private:
443 // The region in which this block was previously contained.
444 Region *region;
445
446 // The original successor of this block before it was unlinked. "nullptr" if
447 // this block was the only block in the region.
448 Block *insertBeforeBlock;
449};
450
451/// Inlining of a block. This rewrite is immediately reflected in the IR.
452/// Note: This rewrite represents only the inlining of the operations. The
453/// erasure of the inlined block is a separate rewrite.
454class InlineBlockRewrite : public BlockRewrite {
455public:
456 InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
457 Block *sourceBlock, Block::iterator before)
458 : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
459 sourceBlock(sourceBlock),
460 firstInlinedInst(sourceBlock->empty() ? nullptr
461 : &sourceBlock->front()),
462 lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
463 // If a listener is attached to the dialect conversion, ops must be moved
464 // one-by-one. When they are moved in bulk, notifications cannot be sent
465 // because the ops that used to be in the source block at the time of the
466 // inlining (before the "commit" phase) are unknown at the time when
467 // notifications are sent (which is during the "commit" phase).
468 assert(!getConfig().listener &&
469 "InlineBlockRewrite not supported if listener is attached");
470 }
471
472 static bool classof(const IRRewrite *rewrite) {
473 return rewrite->getKind() == Kind::InlineBlock;
474 }
475
476 void rollback() override {
477 // Put the operations from the destination block (owned by the rewrite)
478 // back into the source block.
479 if (firstInlinedInst) {
480 assert(lastInlinedInst && "expected operation");
481 sourceBlock->getOperations().splice(where: sourceBlock->begin(),
482 L2&: block->getOperations(),
483 first: Block::iterator(firstInlinedInst),
484 last: ++Block::iterator(lastInlinedInst));
485 }
486 }
487
488private:
489 // The block that originally contained the operations.
490 Block *sourceBlock;
491
492 // The first inlined operation.
493 Operation *firstInlinedInst;
494
495 // The last inlined operation.
496 Operation *lastInlinedInst;
497};
498
499/// Moving of a block. This rewrite is immediately reflected in the IR.
500class MoveBlockRewrite : public BlockRewrite {
501public:
502 MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
503 Region *region, Block *insertBeforeBlock)
504 : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region),
505 insertBeforeBlock(insertBeforeBlock) {}
506
507 static bool classof(const IRRewrite *rewrite) {
508 return rewrite->getKind() == Kind::MoveBlock;
509 }
510
511 void commit(RewriterBase &rewriter) override {
512 // The block was already moved. Just inform the listener.
513 if (auto *listener = rewriter.getListener()) {
514 // Note: `previousIt` cannot be passed because this is a delayed
515 // notification and iterators into past IR state cannot be represented.
516 listener->notifyBlockInserted(block, /*previous=*/region,
517 /*previousIt=*/{});
518 }
519 }
520
521 void rollback() override {
522 // Move the block back to its original position.
523 Region::iterator before =
524 insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
525 region->getBlocks().splice(where: before, L2&: block->getParent()->getBlocks(), N: block);
526 }
527
528private:
529 // The region in which this block was previously contained.
530 Region *region;
531
532 // The original successor of this block before it was moved. "nullptr" if
533 // this block was the only block in the region.
534 Block *insertBeforeBlock;
535};
536
537/// Block type conversion. This rewrite is partially reflected in the IR.
538class BlockTypeConversionRewrite : public BlockRewrite {
539public:
540 BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
541 Block *origBlock, Block *newBlock)
542 : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock),
543 newBlock(newBlock) {}
544
545 static bool classof(const IRRewrite *rewrite) {
546 return rewrite->getKind() == Kind::BlockTypeConversion;
547 }
548
549 Block *getOrigBlock() const { return block; }
550
551 Block *getNewBlock() const { return newBlock; }
552
553 void commit(RewriterBase &rewriter) override;
554
555 void rollback() override;
556
557private:
558 /// The new block that was created as part of this signature conversion.
559 Block *newBlock;
560};
561
562/// Replacing a block argument. This rewrite is not immediately reflected in the
563/// IR. An internal IR mapping is updated, but the actual replacement is delayed
564/// until the rewrite is committed.
565class ReplaceBlockArgRewrite : public BlockRewrite {
566public:
567 ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
568 Block *block, BlockArgument arg,
569 const TypeConverter *converter)
570 : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
571 converter(converter) {}
572
573 static bool classof(const IRRewrite *rewrite) {
574 return rewrite->getKind() == Kind::ReplaceBlockArg;
575 }
576
577 void commit(RewriterBase &rewriter) override;
578
579 void rollback() override;
580
581private:
582 BlockArgument arg;
583
584 /// The current type converter when the block argument was replaced.
585 const TypeConverter *converter;
586};
587
588/// An operation rewrite.
589class OperationRewrite : public IRRewrite {
590public:
591 /// Return the operation that this rewrite operates on.
592 Operation *getOperation() const { return op; }
593
594 static bool classof(const IRRewrite *rewrite) {
595 return rewrite->getKind() >= Kind::MoveOperation &&
596 rewrite->getKind() <= Kind::UnresolvedMaterialization;
597 }
598
599protected:
600 OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
601 Operation *op)
602 : IRRewrite(kind, rewriterImpl), op(op) {}
603
604 // The operation that this rewrite operates on.
605 Operation *op;
606};
607
608/// Moving of an operation. This rewrite is immediately reflected in the IR.
609class MoveOperationRewrite : public OperationRewrite {
610public:
611 MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
612 Operation *op, Block *block, Operation *insertBeforeOp)
613 : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block),
614 insertBeforeOp(insertBeforeOp) {}
615
616 static bool classof(const IRRewrite *rewrite) {
617 return rewrite->getKind() == Kind::MoveOperation;
618 }
619
620 void commit(RewriterBase &rewriter) override {
621 // The operation was already moved. Just inform the listener.
622 if (auto *listener = rewriter.getListener()) {
623 // Note: `previousIt` cannot be passed because this is a delayed
624 // notification and iterators into past IR state cannot be represented.
625 listener->notifyOperationInserted(
626 op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
627 /*insertPt=*/{}));
628 }
629 }
630
631 void rollback() override {
632 // Move the operation back to its original position.
633 Block::iterator before =
634 insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
635 block->getOperations().splice(where: before, L2&: op->getBlock()->getOperations(), N: op);
636 }
637
638private:
639 // The block in which this operation was previously contained.
640 Block *block;
641
642 // The original successor of this operation before it was moved. "nullptr"
643 // if this operation was the only operation in the region.
644 Operation *insertBeforeOp;
645};
646
647/// In-place modification of an op. This rewrite is immediately reflected in
648/// the IR. The previous state of the operation is stored in this object.
649class ModifyOperationRewrite : public OperationRewrite {
650public:
651 ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
652 Operation *op)
653 : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
654 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
655 operands(op->operand_begin(), op->operand_end()),
656 successors(op->successor_begin(), op->successor_end()) {
657 if (OpaqueProperties prop = op->getPropertiesStorage()) {
658 // Make a copy of the properties.
659 propertiesStorage = operator new(op->getPropertiesStorageSize());
660 OpaqueProperties propCopy(propertiesStorage);
661 name.initOpProperties(storage: propCopy, /*init=*/prop);
662 }
663 }
664
665 static bool classof(const IRRewrite *rewrite) {
666 return rewrite->getKind() == Kind::ModifyOperation;
667 }
668
669 ~ModifyOperationRewrite() override {
670 assert(!propertiesStorage &&
671 "rewrite was neither committed nor rolled back");
672 }
673
674 void commit(RewriterBase &rewriter) override {
675 // Notify the listener that the operation was modified in-place.
676 if (auto *listener =
677 dyn_cast_or_null<RewriterBase::Listener>(Val: rewriter.getListener()))
678 listener->notifyOperationModified(op);
679
680 if (propertiesStorage) {
681 OpaqueProperties propCopy(propertiesStorage);
682 // Note: The operation may have been erased in the mean time, so
683 // OperationName must be stored in this object.
684 name.destroyOpProperties(properties: propCopy);
685 operator delete(propertiesStorage);
686 propertiesStorage = nullptr;
687 }
688 }
689
690 void rollback() override {
691 op->setLoc(loc);
692 op->setAttrs(attrs);
693 op->setOperands(operands);
694 for (const auto &it : llvm::enumerate(First&: successors))
695 op->setSuccessor(block: it.value(), index: it.index());
696 if (propertiesStorage) {
697 OpaqueProperties propCopy(propertiesStorage);
698 op->copyProperties(rhs: propCopy);
699 name.destroyOpProperties(properties: propCopy);
700 operator delete(propertiesStorage);
701 propertiesStorage = nullptr;
702 }
703 }
704
705private:
706 OperationName name;
707 LocationAttr loc;
708 DictionaryAttr attrs;
709 SmallVector<Value, 8> operands;
710 SmallVector<Block *, 2> successors;
711 void *propertiesStorage = nullptr;
712};
713
714/// Replacing an operation. Erasing an operation is treated as a special case
715/// with "null" replacements. This rewrite is not immediately reflected in the
716/// IR. An internal IR mapping is updated, but values are not replaced and the
717/// original op is not erased until the rewrite is committed.
718class ReplaceOperationRewrite : public OperationRewrite {
719public:
720 ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
721 Operation *op, const TypeConverter *converter)
722 : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
723 converter(converter) {}
724
725 static bool classof(const IRRewrite *rewrite) {
726 return rewrite->getKind() == Kind::ReplaceOperation;
727 }
728
729 void commit(RewriterBase &rewriter) override;
730
731 void rollback() override;
732
733 void cleanup(RewriterBase &rewriter) override;
734
735private:
736 /// An optional type converter that can be used to materialize conversions
737 /// between the new and old values if necessary.
738 const TypeConverter *converter;
739};
740
741class CreateOperationRewrite : public OperationRewrite {
742public:
743 CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
744 Operation *op)
745 : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
746
747 static bool classof(const IRRewrite *rewrite) {
748 return rewrite->getKind() == Kind::CreateOperation;
749 }
750
751 void commit(RewriterBase &rewriter) override {
752 // The operation was already created and inserted. Just inform the listener.
753 if (auto *listener = rewriter.getListener())
754 listener->notifyOperationInserted(op, /*previous=*/{});
755 }
756
757 void rollback() override;
758};
759
/// Direction of a materialization.
enum MaterializationKind {
  /// Conversion of a value from an illegal type to a legal type.
  Target,

  /// Conversion of a value from a legal type back to an illegal type.
  Source
};
770
771/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
772/// op. Unresolved materializations are erased at the end of the dialect
773/// conversion.
774class UnresolvedMaterializationRewrite : public OperationRewrite {
775public:
776 UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
777 UnrealizedConversionCastOp op,
778 const TypeConverter *converter,
779 MaterializationKind kind, Type originalType,
780 ValueVector mappedValues);
781
782 static bool classof(const IRRewrite *rewrite) {
783 return rewrite->getKind() == Kind::UnresolvedMaterialization;
784 }
785
786 void rollback() override;
787
788 UnrealizedConversionCastOp getOperation() const {
789 return cast<UnrealizedConversionCastOp>(op);
790 }
791
792 /// Return the type converter of this materialization (which may be null).
793 const TypeConverter *getConverter() const {
794 return converterAndKind.getPointer();
795 }
796
797 /// Return the kind of this materialization.
798 MaterializationKind getMaterializationKind() const {
799 return converterAndKind.getInt();
800 }
801
802 /// Return the original type of the SSA value.
803 Type getOriginalType() const { return originalType; }
804
805private:
806 /// The corresponding type converter to use when resolving this
807 /// materialization, and the kind of this materialization.
808 llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
809 converterAndKind;
810
811 /// The original type of the SSA value. Only used for target
812 /// materializations.
813 Type originalType;
814
815 /// The values in the conversion value mapping that are being replaced by the
816 /// results of this unresolved materialization.
817 ValueVector mappedValues;
818};
819} // namespace
820
821#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
822/// Return "true" if there is an operation rewrite that matches the specified
823/// rewrite type and operation among the given rewrites.
824template <typename RewriteTy, typename R>
825static bool hasRewrite(R &&rewrites, Operation *op) {
826 return any_of(std::forward<R>(rewrites), [&](auto &rewrite) {
827 auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
828 return rewriteTy && rewriteTy->getOperation() == op;
829 });
830}
831
832/// Return "true" if there is a block rewrite that matches the specified
833/// rewrite type and block among the given rewrites.
834template <typename RewriteTy, typename R>
835static bool hasRewrite(R &&rewrites, Block *block) {
836 return any_of(std::forward<R>(rewrites), [&](auto &rewrite) {
837 auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
838 return rewriteTy && rewriteTy->getBlock() == block;
839 });
840}
841#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
842
843//===----------------------------------------------------------------------===//
844// ConversionPatternRewriterImpl
845//===----------------------------------------------------------------------===//
846namespace mlir {
847namespace detail {
848struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
849 explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
850 const ConversionConfig &config)
851 : context(ctx), eraseRewriter(ctx), config(config) {}
852
853 //===--------------------------------------------------------------------===//
854 // State Management
855 //===--------------------------------------------------------------------===//
856
857 /// Return the current state of the rewriter.
858 RewriterState getCurrentState();
859
860 /// Apply all requested operation rewrites. This method is invoked when the
861 /// conversion process succeeds.
862 void applyRewrites();
863
864 /// Reset the state of the rewriter to a previously saved point. Optionally,
865 /// the name of the pattern that triggered the rollback can specified for
866 /// debugging purposes.
867 void resetState(RewriterState state, StringRef patternName = "");
868
869 /// Append a rewrite. Rewrites are committed upon success and rolled back upon
870 /// failure.
871 template <typename RewriteTy, typename... Args>
872 void appendRewrite(Args &&...args) {
873 rewrites.push_back(
874 std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
875 }
876
877 /// Undo the rewrites (motions, splits) one by one in reverse order until
878 /// "numRewritesToKeep" rewrites remains. Optionally, the name of the pattern
879 /// that triggered the rollback can specified for debugging purposes.
880 void undoRewrites(unsigned numRewritesToKeep = 0, StringRef patternName = "");
881
882 /// Remap the given values to those with potentially different types. Returns
883 /// success if the values could be remapped, failure otherwise. `valueDiagTag`
884 /// is the tag used when describing a value within a diagnostic, e.g.
885 /// "operand".
886 LogicalResult remapValues(StringRef valueDiagTag,
887 std::optional<Location> inputLoc,
888 PatternRewriter &rewriter, ValueRange values,
889 SmallVector<ValueVector> &remapped);
890
891 /// Return "true" if the given operation is ignored, and does not need to be
892 /// converted.
893 bool isOpIgnored(Operation *op) const;
894
895 /// Return "true" if the given operation was replaced or erased.
896 bool wasOpReplaced(Operation *op) const;
897
898 //===--------------------------------------------------------------------===//
899 // Type Conversion
900 //===--------------------------------------------------------------------===//
901
902 /// Convert the types of block arguments within the given region.
903 FailureOr<Block *>
904 convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
905 const TypeConverter &converter,
906 TypeConverter::SignatureConversion *entryConversion);
907
908 /// Apply the given signature conversion on the given block. The new block
909 /// containing the updated signature is returned. If no conversions were
910 /// necessary, e.g. if the block has no arguments, `block` is returned.
911 /// `converter` is used to generate any necessary cast operations that
912 /// translate between the origin argument types and those specified in the
913 /// signature conversion.
914 Block *applySignatureConversion(
915 ConversionPatternRewriter &rewriter, Block *block,
916 const TypeConverter *converter,
917 TypeConverter::SignatureConversion &signatureConversion);
918
919 //===--------------------------------------------------------------------===//
920 // Materializations
921 //===--------------------------------------------------------------------===//
922
923 /// Build an unresolved materialization operation given a range of output
924 /// types and a list of input operands. Returns the inputs if they their
925 /// types match the output types.
926 ///
927 /// If a cast op was built, it can optionally be returned with the `castOp`
928 /// output argument.
929 ///
930 /// If `valuesToMap` is set to a non-null Value, then that value is mapped to
931 /// the results of the unresolved materialization in the conversion value
932 /// mapping.
933 ValueRange buildUnresolvedMaterialization(
934 MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
935 ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
936 Type originalType, const TypeConverter *converter,
937 UnrealizedConversionCastOp *castOp = nullptr);
938
939 /// Find a replacement value for the given SSA value in the conversion value
940 /// mapping. The replacement value must have the same type as the given SSA
941 /// value. If there is no replacement value with the correct type, find the
942 /// latest replacement value (regardless of the type) and build a source
943 /// materialization.
944 Value findOrBuildReplacementValue(Value value,
945 const TypeConverter *converter);
946
947 //===--------------------------------------------------------------------===//
948 // Rewriter Notification Hooks
949 //===--------------------------------------------------------------------===//
950
951 //// Notifies that an op was inserted.
952 void notifyOperationInserted(Operation *op,
953 OpBuilder::InsertPoint previous) override;
954
955 /// Notifies that an op is about to be replaced with the given values.
956 void notifyOpReplaced(Operation *op,
957 SmallVector<SmallVector<Value>> &&newValues);
958
959 /// Notifies that a block is about to be erased.
960 void notifyBlockIsBeingErased(Block *block);
961
962 /// Notifies that a block was inserted.
963 void notifyBlockInserted(Block *block, Region *previous,
964 Region::iterator previousIt) override;
965
966 /// Notifies that a block is being inlined into another block.
967 void notifyBlockBeingInlined(Block *block, Block *srcBlock,
968 Block::iterator before);
969
970 /// Notifies that a pattern match failed for the given reason.
971 void
972 notifyMatchFailure(Location loc,
973 function_ref<void(Diagnostic &)> reasonCallback) override;
974
975 //===--------------------------------------------------------------------===//
976 // IR Erasure
977 //===--------------------------------------------------------------------===//
978
979 /// A rewriter that keeps track of erased ops and blocks. It ensures that no
980 /// operation or block is erased multiple times. This rewriter assumes that
981 /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
982 struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener {
983 public:
984 SingleEraseRewriter(MLIRContext *context)
985 : RewriterBase(context, /*listener=*/this) {}
986
987 /// Erase the given op (unless it was already erased).
988 void eraseOp(Operation *op) override {
989 if (wasErased(ptr: op))
990 return;
991 op->dropAllUses();
992 RewriterBase::eraseOp(op);
993 }
994
995 /// Erase the given block (unless it was already erased).
996 void eraseBlock(Block *block) override {
997 if (wasErased(ptr: block))
998 return;
999 assert(block->empty() && "expected empty block");
1000 block->dropAllDefinedValueUses();
1001 RewriterBase::eraseBlock(block);
1002 }
1003
1004 bool wasErased(void *ptr) const { return erased.contains(V: ptr); }
1005
1006 void notifyOperationErased(Operation *op) override { erased.insert(V: op); }
1007
1008 void notifyBlockErased(Block *block) override { erased.insert(V: block); }
1009
1010 private:
1011 /// Pointers to all erased operations and blocks.
1012 DenseSet<void *> erased;
1013 };
1014
1015 //===--------------------------------------------------------------------===//
1016 // State
1017 //===--------------------------------------------------------------------===//
1018
1019 /// MLIR context.
1020 MLIRContext *context;
1021
1022 /// A rewriter that keeps track of ops/block that were already erased and
1023 /// skips duplicate op/block erasures. This rewriter is used during the
1024 /// "cleanup" phase.
1025 SingleEraseRewriter eraseRewriter;
1026
1027 // Mapping between replaced values that differ in type. This happens when
1028 // replacing a value with one of a different type.
1029 ConversionValueMapping mapping;
1030
1031 /// Ordered list of block operations (creations, splits, motions).
1032 SmallVector<std::unique_ptr<IRRewrite>> rewrites;
1033
1034 /// A set of operations that should no longer be considered for legalization.
1035 /// E.g., ops that are recursively legal. Ops that were replaced/erased are
1036 /// tracked separately.
1037 SetVector<Operation *> ignoredOps;
1038
1039 /// A set of operations that were replaced/erased. Such ops are not erased
1040 /// immediately but only when the dialect conversion succeeds. In the mean
1041 /// time, they should no longer be considered for legalization and any attempt
1042 /// to modify/access them is invalid rewriter API usage.
1043 SetVector<Operation *> replacedOps;
1044
1045 /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
1046 /// to the corresponding rewrite objects.
1047 DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
1048 unresolvedMaterializations;
1049
1050 /// The current type converter, or nullptr if no type converter is currently
1051 /// active.
1052 const TypeConverter *currentTypeConverter = nullptr;
1053
1054 /// A mapping of regions to type converters that should be used when
1055 /// converting the arguments of blocks within that region.
1056 DenseMap<Region *, const TypeConverter *> regionToConverter;
1057
1058 /// Dialect conversion configuration.
1059 const ConversionConfig &config;
1060
1061#ifndef NDEBUG
1062 /// A set of operations that have pending updates. This tracking isn't
1063 /// strictly necessary, and is thus only active during debug builds for extra
1064 /// verification.
1065 SmallPtrSet<Operation *, 1> pendingRootUpdates;
1066
1067 /// A logger used to emit diagnostics during the conversion process.
1068 llvm::ScopedPrinter logger{llvm::dbgs()};
1069#endif
1070};
1071} // namespace detail
1072} // namespace mlir
1073
1074const ConversionConfig &IRRewrite::getConfig() const {
1075 return rewriterImpl.config;
1076}
1077
1078void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1079 // Inform the listener about all IR modifications that have already taken
1080 // place: References to the original block have been replaced with the new
1081 // block.
1082 if (auto *listener =
1083 dyn_cast_or_null<RewriterBase::Listener>(Val: rewriter.getListener()))
1084 for (Operation *op : getNewBlock()->getUsers())
1085 listener->notifyOperationModified(op);
1086}
1087
1088void BlockTypeConversionRewrite::rollback() {
1089 getNewBlock()->replaceAllUsesWith(newValue: getOrigBlock());
1090}
1091
1092void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
1093 Value repl = rewriterImpl.findOrBuildReplacementValue(value: arg, converter);
1094 if (!repl)
1095 return;
1096
1097 if (isa<BlockArgument>(Val: repl)) {
1098 rewriter.replaceAllUsesWith(from: arg, to: repl);
1099 return;
1100 }
1101
1102 // If the replacement value is an operation, we check to make sure that we
1103 // don't replace uses that are within the parent operation of the
1104 // replacement value.
1105 Operation *replOp = cast<OpResult>(Val&: repl).getOwner();
1106 Block *replBlock = replOp->getBlock();
1107 rewriter.replaceUsesWithIf(from: arg, to: repl, functor: [&](OpOperand &operand) {
1108 Operation *user = operand.getOwner();
1109 return user->getBlock() != replBlock || replOp->isBeforeInBlock(other: user);
1110 });
1111}
1112
1113void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(value: {arg}); }
1114
1115void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1116 auto *listener =
1117 dyn_cast_or_null<RewriterBase::Listener>(Val: rewriter.getListener());
1118
1119 // Compute replacement values.
1120 SmallVector<Value> replacements =
1121 llvm::map_to_vector(C: op->getResults(), F: [&](OpResult result) {
1122 return rewriterImpl.findOrBuildReplacementValue(value: result, converter);
1123 });
1124
1125 // Notify the listener that the operation is about to be replaced.
1126 if (listener)
1127 listener->notifyOperationReplaced(op, replacement: replacements);
1128
1129 // Replace all uses with the new values.
1130 for (auto [result, newValue] :
1131 llvm::zip_equal(t: op->getResults(), u&: replacements))
1132 if (newValue)
1133 rewriter.replaceAllUsesWith(from: result, to: newValue);
1134
1135 // The original op will be erased, so remove it from the set of unlegalized
1136 // ops.
1137 if (getConfig().unlegalizedOps)
1138 getConfig().unlegalizedOps->erase(V: op);
1139
1140 // Notify the listener that the operation (and its nested operations) was
1141 // erased.
1142 if (listener) {
1143 op->walk<WalkOrder::PostOrder>(
1144 callback: [&](Operation *op) { listener->notifyOperationErased(op); });
1145 }
1146
1147 // Do not erase the operation yet. It may still be referenced in `mapping`.
1148 // Just unlink it for now and erase it during cleanup.
1149 op->getBlock()->getOperations().remove(IT: op);
1150}
1151
1152void ReplaceOperationRewrite::rollback() {
1153 for (auto result : op->getResults())
1154 rewriterImpl.mapping.erase(value: {result});
1155}
1156
1157void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1158 rewriter.eraseOp(op);
1159}
1160
1161void CreateOperationRewrite::rollback() {
1162 for (Region &region : op->getRegions()) {
1163 while (!region.getBlocks().empty())
1164 region.getBlocks().remove(IT: region.getBlocks().begin());
1165 }
1166 op->dropAllUses();
1167 op->erase();
1168}
1169
1170UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
1171 ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
1172 const TypeConverter *converter, MaterializationKind kind, Type originalType,
1173 ValueVector mappedValues)
1174 : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
1175 converterAndKind(converter, kind), originalType(originalType),
1176 mappedValues(std::move(mappedValues)) {
1177 assert((!originalType || kind == MaterializationKind::Target) &&
1178 "original type is valid only for target materializations");
1179 rewriterImpl.unresolvedMaterializations[op] = this;
1180}
1181
1182void UnresolvedMaterializationRewrite::rollback() {
1183 if (!mappedValues.empty())
1184 rewriterImpl.mapping.erase(value: mappedValues);
1185 rewriterImpl.unresolvedMaterializations.erase(getOperation());
1186 op->erase();
1187}
1188
1189void ConversionPatternRewriterImpl::applyRewrites() {
1190 // Commit all rewrites.
1191 IRRewriter rewriter(context, config.listener);
1192 // Note: New rewrites may be added during the "commit" phase and the
1193 // `rewrites` vector may reallocate.
1194 for (size_t i = 0; i < rewrites.size(); ++i)
1195 rewrites[i]->commit(rewriter);
1196
1197 // Clean up all rewrites.
1198 for (auto &rewrite : rewrites)
1199 rewrite->cleanup(rewriter&: eraseRewriter);
1200}
1201
1202//===----------------------------------------------------------------------===//
1203// State Management
1204//===----------------------------------------------------------------------===//
1205
1206RewriterState ConversionPatternRewriterImpl::getCurrentState() {
1207 return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
1208}
1209
1210void ConversionPatternRewriterImpl::resetState(RewriterState state,
1211 StringRef patternName) {
1212 // Undo any rewrites.
1213 undoRewrites(numRewritesToKeep: state.numRewrites, patternName);
1214
1215 // Pop all of the recorded ignored operations that are no longer valid.
1216 while (ignoredOps.size() != state.numIgnoredOperations)
1217 ignoredOps.pop_back();
1218
1219 while (replacedOps.size() != state.numReplacedOps)
1220 replacedOps.pop_back();
1221}
1222
1223void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
1224 StringRef patternName) {
1225 for (auto &rewrite :
1226 llvm::reverse(C: llvm::drop_begin(RangeOrContainer&: rewrites, N: numRewritesToKeep))) {
1227 if (!config.allowPatternRollback &&
1228 !isa<UnresolvedMaterializationRewrite>(Val: rewrite)) {
1229 // Unresolved materializations can always be rolled back (erased).
1230 llvm::report_fatal_error(reason: "pattern '" + patternName +
1231 "' rollback of IR modifications requested");
1232 }
1233 rewrite->rollback();
1234 }
1235 rewrites.resize(N: numRewritesToKeep);
1236}
1237
1238LogicalResult ConversionPatternRewriterImpl::remapValues(
1239 StringRef valueDiagTag, std::optional<Location> inputLoc,
1240 PatternRewriter &rewriter, ValueRange values,
1241 SmallVector<ValueVector> &remapped) {
1242 remapped.reserve(N: llvm::size(Range&: values));
1243
1244 for (const auto &it : llvm::enumerate(First&: values)) {
1245 Value operand = it.value();
1246 Type origType = operand.getType();
1247 Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1248
1249 if (!currentTypeConverter) {
1250 // The current pattern does not have a type converter. I.e., it does not
1251 // distinguish between legal and illegal types. For each operand, simply
1252 // pass through the most recently mapped values.
1253 remapped.push_back(Elt: mapping.lookupOrDefault(from: operand));
1254 continue;
1255 }
1256
1257 // If there is no legal conversion, fail to match this pattern.
1258 SmallVector<Type, 1> legalTypes;
1259 if (failed(Result: currentTypeConverter->convertType(t: origType, results&: legalTypes))) {
1260 notifyMatchFailure(loc: operandLoc, reasonCallback: [=](Diagnostic &diag) {
1261 diag << "unable to convert type for " << valueDiagTag << " #"
1262 << it.index() << ", type was " << origType;
1263 });
1264 return failure();
1265 }
1266 // If a type is converted to 0 types, there is nothing to do.
1267 if (legalTypes.empty()) {
1268 remapped.push_back(Elt: {});
1269 continue;
1270 }
1271
1272 ValueVector repl = mapping.lookupOrDefault(from: operand, desiredTypes: legalTypes);
1273 if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) {
1274 // Mapped values have the correct type or there is an existing
1275 // materialization. Or the operand is not mapped at all and has the
1276 // correct type.
1277 remapped.push_back(Elt: std::move(repl));
1278 continue;
1279 }
1280
1281 // Create a materialization for the most recently mapped values.
1282 repl = mapping.lookupOrDefault(from: operand);
1283 ValueRange castValues = buildUnresolvedMaterialization(
1284 MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
1285 /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
1286 /*originalType=*/origType, currentTypeConverter);
1287 remapped.push_back(Elt: castValues);
1288 }
1289 return success();
1290}
1291
1292bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
1293 // Check to see if this operation is ignored or was replaced.
1294 return replacedOps.count(key: op) || ignoredOps.count(key: op);
1295}
1296
1297bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
1298 // Check to see if this operation was replaced.
1299 return replacedOps.count(key: op);
1300}
1301
1302//===----------------------------------------------------------------------===//
1303// Type Conversion
1304//===----------------------------------------------------------------------===//
1305
1306FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
1307 ConversionPatternRewriter &rewriter, Region *region,
1308 const TypeConverter &converter,
1309 TypeConverter::SignatureConversion *entryConversion) {
1310 regionToConverter[region] = &converter;
1311 if (region->empty())
1312 return nullptr;
1313
1314 // Convert the arguments of each non-entry block within the region.
1315 for (Block &block :
1316 llvm::make_early_inc_range(Range: llvm::drop_begin(RangeOrContainer&: *region, N: 1))) {
1317 // Compute the signature for the block with the provided converter.
1318 std::optional<TypeConverter::SignatureConversion> conversion =
1319 converter.convertBlockSignature(block: &block);
1320 if (!conversion)
1321 return failure();
1322 // Convert the block with the computed signature.
1323 applySignatureConversion(rewriter, block: &block, converter: &converter, signatureConversion&: *conversion);
1324 }
1325
1326 // Convert the entry block. If an entry signature conversion was provided,
1327 // use that one. Otherwise, compute the signature with the type converter.
1328 if (entryConversion)
1329 return applySignatureConversion(rewriter, block: &region->front(), converter: &converter,
1330 signatureConversion&: *entryConversion);
1331 std::optional<TypeConverter::SignatureConversion> conversion =
1332 converter.convertBlockSignature(block: &region->front());
1333 if (!conversion)
1334 return failure();
1335 return applySignatureConversion(rewriter, block: &region->front(), converter: &converter,
1336 signatureConversion&: *conversion);
1337}
1338
1339Block *ConversionPatternRewriterImpl::applySignatureConversion(
1340 ConversionPatternRewriter &rewriter, Block *block,
1341 const TypeConverter *converter,
1342 TypeConverter::SignatureConversion &signatureConversion) {
1343#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1344 // A block cannot be converted multiple times.
1345 if (hasRewrite<BlockTypeConversionRewrite>(rewrites, block))
1346 llvm::report_fatal_error("block was already converted");
1347#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1348
1349 OpBuilder::InsertionGuard g(rewriter);
1350
1351 // If no arguments are being changed or added, there is nothing to do.
1352 unsigned origArgCount = block->getNumArguments();
1353 auto convertedTypes = signatureConversion.getConvertedTypes();
1354 if (llvm::equal(LRange: block->getArgumentTypes(), RRange&: convertedTypes))
1355 return block;
1356
1357 // Compute the locations of all block arguments in the new block.
1358 SmallVector<Location> newLocs(convertedTypes.size(),
1359 rewriter.getUnknownLoc());
1360 for (unsigned i = 0; i < origArgCount; ++i) {
1361 auto inputMap = signatureConversion.getInputMapping(input: i);
1362 if (!inputMap || inputMap->replacedWithValues())
1363 continue;
1364 Location origLoc = block->getArgument(i).getLoc();
1365 for (unsigned j = 0; j < inputMap->size; ++j)
1366 newLocs[inputMap->inputNo + j] = origLoc;
1367 }
1368
1369 // Insert a new block with the converted block argument types and move all ops
1370 // from the old block to the new block.
1371 Block *newBlock =
1372 rewriter.createBlock(parent: block->getParent(), insertPt: std::next(x: block->getIterator()),
1373 argTypes: convertedTypes, locs: newLocs);
1374
1375 // If a listener is attached to the dialect conversion, ops cannot be moved
1376 // to the destination block in bulk ("fast path"). This is because at the time
1377 // the notifications are sent, it is unknown which ops were moved. Instead,
1378 // ops should be moved one-by-one ("slow path"), so that a separate
1379 // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1380 // a bit more efficient, so we try to do that when possible.
1381 bool fastPath = !config.listener;
1382 if (fastPath) {
1383 appendRewrite<InlineBlockRewrite>(args&: newBlock, args&: block, args: newBlock->end());
1384 newBlock->getOperations().splice(where: newBlock->end(), L2&: block->getOperations());
1385 } else {
1386 while (!block->empty())
1387 rewriter.moveOpBefore(op: &block->front(), block: newBlock, iterator: newBlock->end());
1388 }
1389
1390 // Replace all uses of the old block with the new block.
1391 block->replaceAllUsesWith(newValue&: newBlock);
1392
1393 for (unsigned i = 0; i != origArgCount; ++i) {
1394 BlockArgument origArg = block->getArgument(i);
1395 Type origArgType = origArg.getType();
1396
1397 std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1398 signatureConversion.getInputMapping(input: i);
1399 if (!inputMap) {
1400 // This block argument was dropped and no replacement value was provided.
1401 // Materialize a replacement value "out of thin air".
1402 buildUnresolvedMaterialization(
1403 MaterializationKind::Source,
1404 OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1405 /*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(),
1406 /*outputTypes=*/origArgType, /*originalType=*/Type(), converter);
1407 appendRewrite<ReplaceBlockArgRewrite>(args&: block, args&: origArg, args&: converter);
1408 continue;
1409 }
1410
1411 if (inputMap->replacedWithValues()) {
1412 // This block argument was dropped and replacement values were provided.
1413 assert(inputMap->size == 0 &&
1414 "invalid to provide a replacement value when the argument isn't "
1415 "dropped");
1416 mapping.map(oldVal&: origArg, newVal&: inputMap->replacementValues);
1417 appendRewrite<ReplaceBlockArgRewrite>(args&: block, args&: origArg, args&: converter);
1418 continue;
1419 }
1420
1421 // This is a 1->1+ mapping.
1422 auto replArgs =
1423 newBlock->getArguments().slice(N: inputMap->inputNo, M: inputMap->size);
1424 ValueVector replArgVals = llvm::to_vector_of<Value, 1>(Range&: replArgs);
1425 mapping.map(oldVal&: origArg, newVal: std::move(replArgVals));
1426 appendRewrite<ReplaceBlockArgRewrite>(args&: block, args&: origArg, args&: converter);
1427 }
1428
1429 appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/args&: block, args&: newBlock);
1430
1431 // Erase the old block. (It is just unlinked for now and will be erased during
1432 // cleanup.)
1433 rewriter.eraseBlock(block);
1434
1435 return newBlock;
1436}
1437
1438//===----------------------------------------------------------------------===//
1439// Materializations
1440//===----------------------------------------------------------------------===//
1441
1442/// Build an unresolved materialization operation given an output type and set
1443/// of input operands.
1444ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
1445 MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1446 ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
1447 Type originalType, const TypeConverter *converter,
1448 UnrealizedConversionCastOp *castOp) {
1449 assert((!originalType || kind == MaterializationKind::Target) &&
1450 "original type is valid only for target materializations");
1451 assert(TypeRange(inputs) != outputTypes &&
1452 "materialization is not necessary");
1453
1454 // Create an unresolved materialization. We use a new OpBuilder to avoid
1455 // tracking the materialization like we do for other operations.
1456 OpBuilder builder(outputTypes.front().getContext());
1457 builder.setInsertionPoint(block: ip.getBlock(), insertPoint: ip.getPoint());
1458 auto convertOp =
1459 builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs);
1460 if (!valuesToMap.empty())
1461 mapping.map(valuesToMap, convertOp.getResults());
1462 if (castOp)
1463 *castOp = convertOp;
1464 appendRewrite<UnresolvedMaterializationRewrite>(
1465 convertOp, converter, kind, originalType, std::move(valuesToMap));
1466 return convertOp.getResults();
1467}
1468
1469Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
1470 Value value, const TypeConverter *converter) {
1471 // Try to find a replacement value with the same type in the conversion value
1472 // mapping. This includes cached materializations. We try to reuse those
1473 // instead of generating duplicate IR.
1474 ValueVector repl = mapping.lookupOrNull(from: value, desiredTypes: value.getType());
1475 if (!repl.empty())
1476 return repl.front();
1477
1478 // Check if the value is dead. No replacement value is needed in that case.
1479 // This is an approximate check that may have false negatives but does not
1480 // require computing and traversing an inverse mapping. (We may end up
1481 // building source materializations that are never used and that fold away.)
1482 if (llvm::all_of(Range: value.getUsers(),
1483 P: [&](Operation *op) { return replacedOps.contains(key: op); }) &&
1484 !mapping.isMappedTo(value))
1485 return Value();
1486
1487 // No replacement value was found. Get the latest replacement value
1488 // (regardless of the type) and build a source materialization to the
1489 // original type.
1490 repl = mapping.lookupOrNull(from: value);
1491 if (repl.empty()) {
1492 // No replacement value is registered in the mapping. This means that the
1493 // value is dropped and no longer needed. (If the value were still needed,
1494 // a source materialization producing a replacement value "out of thin air"
1495 // would have already been created during `replaceOp` or
1496 // `applySignatureConversion`.)
1497 return Value();
1498 }
1499
1500 // Note: `computeInsertPoint` computes the "earliest" insertion point at
1501 // which all values in `repl` are defined. It is important to emit the
1502 // materialization at that location because the same materialization may be
1503 // reused in a different context. (That's because materializations are cached
1504 // in the conversion value mapping.) The insertion point of the
1505 // materialization must be valid for all future users that may be created
1506 // later in the conversion process.
1507 Value castValue =
1508 buildUnresolvedMaterialization(MaterializationKind::Source,
1509 computeInsertPoint(repl), value.getLoc(),
1510 /*valuesToMap=*/repl, /*inputs=*/repl,
1511 /*outputTypes=*/value.getType(),
1512 /*originalType=*/Type(), converter)
1513 .front();
1514 return castValue;
1515}
1516
1517//===----------------------------------------------------------------------===//
1518// Rewriter Notification Hooks
1519//===----------------------------------------------------------------------===//
1520
1521void ConversionPatternRewriterImpl::notifyOperationInserted(
1522 Operation *op, OpBuilder::InsertPoint previous) {
1523 LLVM_DEBUG({
1524 logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
1525 << ")\n";
1526 });
1527 assert(!wasOpReplaced(op->getParentOp()) &&
1528 "attempting to insert into a block within a replaced/erased op");
1529
1530 if (!previous.isSet()) {
1531 // This is a newly created op.
1532 appendRewrite<CreateOperationRewrite>(args&: op);
1533 return;
1534 }
1535 Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
1536 ? nullptr
1537 : &*previous.getPoint();
1538 appendRewrite<MoveOperationRewrite>(args&: op, args: previous.getBlock(), args&: prevOp);
1539}
1540
1541void ConversionPatternRewriterImpl::notifyOpReplaced(
1542 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
1543 assert(newValues.size() == op->getNumResults());
1544 assert(!ignoredOps.contains(op) && "operation was already replaced");
1545
1546 // Check if replaced op is an unresolved materialization, i.e., an
1547 // unrealized_conversion_cast op that was created by the conversion driver.
1548 bool isUnresolvedMaterialization = false;
1549 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1550 if (unresolvedMaterializations.contains(castOp))
1551 isUnresolvedMaterialization = true;
1552
1553 // Create mappings for each of the new result values.
1554 for (auto [repl, result] : llvm::zip_equal(t&: newValues, u: op->getResults())) {
1555 if (repl.empty()) {
1556 // This result was dropped and no replacement value was provided.
1557 if (isUnresolvedMaterialization) {
1558 // Do not create another materializations if we are erasing a
1559 // materialization.
1560 continue;
1561 }
1562
1563 // Materialize a replacement value "out of thin air".
1564 buildUnresolvedMaterialization(
1565 MaterializationKind::Source, computeInsertPoint(result),
1566 result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
1567 /*outputTypes=*/result.getType(), /*originalType=*/Type(),
1568 currentTypeConverter);
1569 continue;
1570 } else {
1571 // Make sure that the user does not mess with unresolved materializations
1572 // that were inserted by the conversion driver. We keep track of these
1573 // ops in internal data structures. Erasing them must be allowed because
1574 // this can happen when the user is erasing an entire block (including
1575 // its body). But replacing them with another value should be forbidden
1576 // to avoid problems with the `mapping`.
1577 assert(!isUnresolvedMaterialization &&
1578 "attempting to replace an unresolved materialization");
1579 }
1580
1581 // Remap result to replacement value.
1582 if (repl.empty())
1583 continue;
1584 mapping.map(oldVal: static_cast<Value>(result), newVal: std::move(repl));
1585 }
1586
1587 appendRewrite<ReplaceOperationRewrite>(args&: op, args&: currentTypeConverter);
1588 // Mark this operation and all nested ops as replaced.
1589 op->walk(callback: [&](Operation *op) { replacedOps.insert(X: op); });
1590}
1591
1592void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
1593 appendRewrite<EraseBlockRewrite>(args&: block);
1594}
1595
1596void ConversionPatternRewriterImpl::notifyBlockInserted(
1597 Block *block, Region *previous, Region::iterator previousIt) {
1598 assert(!wasOpReplaced(block->getParentOp()) &&
1599 "attempting to insert into a region within a replaced/erased op");
1600 LLVM_DEBUG(
1601 {
1602 Operation *parent = block->getParentOp();
1603 if (parent) {
1604 logger.startLine() << "** Insert Block into : '" << parent->getName()
1605 << "'(" << parent << ")\n";
1606 } else {
1607 logger.startLine()
1608 << "** Insert Block into detached Region (nullptr parent op)'\n";
1609 }
1610 });
1611
1612 if (!previous) {
1613 // This is a newly created block.
1614 appendRewrite<CreateBlockRewrite>(args&: block);
1615 return;
1616 }
1617 Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
1618 appendRewrite<MoveBlockRewrite>(args&: block, args&: previous, args&: prevBlock);
1619}
1620
1621void ConversionPatternRewriterImpl::notifyBlockBeingInlined(
1622 Block *block, Block *srcBlock, Block::iterator before) {
1623 appendRewrite<InlineBlockRewrite>(args&: block, args&: srcBlock, args&: before);
1624}
1625
1626void ConversionPatternRewriterImpl::notifyMatchFailure(
1627 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1628 LLVM_DEBUG({
1629 Diagnostic diag(loc, DiagnosticSeverity::Remark);
1630 reasonCallback(diag);
1631 logger.startLine() << "** Failure : " << diag.str() << "\n";
1632 if (config.notifyCallback)
1633 config.notifyCallback(diag);
1634 });
1635}
1636
1637//===----------------------------------------------------------------------===//
1638// ConversionPatternRewriter
1639//===----------------------------------------------------------------------===//
1640
1641ConversionPatternRewriter::ConversionPatternRewriter(
1642 MLIRContext *ctx, const ConversionConfig &config)
1643 : PatternRewriter(ctx),
1644 impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
1645 setListener(impl.get());
1646}
1647
1648ConversionPatternRewriter::~ConversionPatternRewriter() = default;
1649
1650void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
1651 assert(op && newOp && "expected non-null op");
1652 replaceOp(op, newValues: newOp->getResults());
1653}
1654
1655void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
1656 assert(op->getNumResults() == newValues.size() &&
1657 "incorrect # of replacement values");
1658 LLVM_DEBUG({
1659 impl->logger.startLine()
1660 << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1661 });
1662 SmallVector<SmallVector<Value>> newVals =
1663 llvm::map_to_vector(C&: newValues, F: [](Value v) -> SmallVector<Value> {
1664 return v ? SmallVector<Value>{v} : SmallVector<Value>();
1665 });
1666 impl->notifyOpReplaced(op, newValues: std::move(newVals));
1667}
1668
1669void ConversionPatternRewriter::replaceOpWithMultiple(
1670 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
1671 assert(op->getNumResults() == newValues.size() &&
1672 "incorrect # of replacement values");
1673 LLVM_DEBUG({
1674 impl->logger.startLine()
1675 << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1676 });
1677 impl->notifyOpReplaced(op, newValues: std::move(newValues));
1678}
1679
1680void ConversionPatternRewriter::eraseOp(Operation *op) {
1681 LLVM_DEBUG({
1682 impl->logger.startLine()
1683 << "** Erase : '" << op->getName() << "'(" << op << ")\n";
1684 });
1685 SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {});
1686 impl->notifyOpReplaced(op, newValues: std::move(nullRepls));
1687}
1688
1689void ConversionPatternRewriter::eraseBlock(Block *block) {
1690 assert(!impl->wasOpReplaced(block->getParentOp()) &&
1691 "attempting to erase a block within a replaced/erased op");
1692
1693 // Mark all ops for erasure.
1694 for (Operation &op : *block)
1695 eraseOp(op: &op);
1696
1697 // Unlink the block from its parent region. The block is kept in the rewrite
1698 // object and will be actually destroyed when rewrites are applied. This
1699 // allows us to keep the operations in the block live and undo the removal by
1700 // re-inserting the block.
1701 impl->notifyBlockIsBeingErased(block);
1702 block->getParent()->getBlocks().remove(IT: block);
1703}
1704
1705Block *ConversionPatternRewriter::applySignatureConversion(
1706 Block *block, TypeConverter::SignatureConversion &conversion,
1707 const TypeConverter *converter) {
1708 assert(!impl->wasOpReplaced(block->getParentOp()) &&
1709 "attempting to apply a signature conversion to a block within a "
1710 "replaced/erased op");
1711 return impl->applySignatureConversion(rewriter&: *this, block, converter, signatureConversion&: conversion);
1712}
1713
1714FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
1715 Region *region, const TypeConverter &converter,
1716 TypeConverter::SignatureConversion *entryConversion) {
1717 assert(!impl->wasOpReplaced(region->getParentOp()) &&
1718 "attempting to apply a signature conversion to a block within a "
1719 "replaced/erased op");
1720 return impl->convertRegionTypes(rewriter&: *this, region, converter, entryConversion);
1721}
1722
1723void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
1724 Value to) {
1725 LLVM_DEBUG({
1726 impl->logger.startLine() << "** Replace Argument : '" << from << "'";
1727 if (Operation *parentOp = from.getOwner()->getParentOp()) {
1728 impl->logger.getOStream() << " (in region of '" << parentOp->getName()
1729 << "' (" << parentOp << ")\n";
1730 } else {
1731 impl->logger.getOStream() << " (unlinked block)\n";
1732 }
1733 });
1734 impl->appendRewrite<ReplaceBlockArgRewrite>(args: from.getOwner(), args&: from,
1735 args&: impl->currentTypeConverter);
1736 impl->mapping.map(oldVal: impl->mapping.lookupOrDefault(from), newVal&: to);
1737}
1738
1739Value ConversionPatternRewriter::getRemappedValue(Value key) {
1740 SmallVector<ValueVector> remappedValues;
1741 if (failed(Result: impl->remapValues(valueDiagTag: "value", /*inputLoc=*/std::nullopt, rewriter&: *this, values: key,
1742 remapped&: remappedValues)))
1743 return nullptr;
1744 assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
1745 return remappedValues.front().front();
1746}
1747
1748LogicalResult
1749ConversionPatternRewriter::getRemappedValues(ValueRange keys,
1750 SmallVectorImpl<Value> &results) {
1751 if (keys.empty())
1752 return success();
1753 SmallVector<ValueVector> remapped;
1754 if (failed(Result: impl->remapValues(valueDiagTag: "value", /*inputLoc=*/std::nullopt, rewriter&: *this, values: keys,
1755 remapped)))
1756 return failure();
1757 for (const auto &values : remapped) {
1758 assert(values.size() == 1 && "1:N conversion not supported");
1759 results.push_back(Elt: values.front());
1760 }
1761 return success();
1762}
1763
1764void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
1765 Block::iterator before,
1766 ValueRange argValues) {
1767#ifndef NDEBUG
1768 assert(argValues.size() == source->getNumArguments() &&
1769 "incorrect # of argument replacement values");
1770 assert(!impl->wasOpReplaced(source->getParentOp()) &&
1771 "attempting to inline a block from a replaced/erased op");
1772 assert(!impl->wasOpReplaced(dest->getParentOp()) &&
1773 "attempting to inline a block into a replaced/erased op");
1774 auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
1775 // The source block will be deleted, so it should not have any users (i.e.,
1776 // there should be no predecessors).
1777 assert(llvm::all_of(source->getUsers(), opIgnored) &&
1778 "expected 'source' to have no predecessors");
1779#endif // NDEBUG
1780
1781 // If a listener is attached to the dialect conversion, ops cannot be moved
1782 // to the destination block in bulk ("fast path"). This is because at the time
1783 // the notifications are sent, it is unknown which ops were moved. Instead,
1784 // ops should be moved one-by-one ("slow path"), so that a separate
1785 // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1786 // a bit more efficient, so we try to do that when possible.
1787 bool fastPath = !impl->config.listener;
1788
1789 if (fastPath)
1790 impl->notifyBlockBeingInlined(block: dest, srcBlock: source, before);
1791
1792 // Replace all uses of block arguments.
1793 for (auto it : llvm::zip(t: source->getArguments(), u&: argValues))
1794 replaceUsesOfBlockArgument(from: std::get<0>(t&: it), to: std::get<1>(t&: it));
1795
1796 if (fastPath) {
1797 // Move all ops at once.
1798 dest->getOperations().splice(where: before, L2&: source->getOperations());
1799 } else {
1800 // Move op by op.
1801 while (!source->empty())
1802 moveOpBefore(op: &source->front(), block: dest, iterator: before);
1803 }
1804
1805 // Erase the source block.
1806 eraseBlock(block: source);
1807}
1808
1809void ConversionPatternRewriter::startOpModification(Operation *op) {
1810 assert(!impl->wasOpReplaced(op) &&
1811 "attempting to modify a replaced/erased op");
1812#ifndef NDEBUG
1813 impl->pendingRootUpdates.insert(Ptr: op);
1814#endif
1815 impl->appendRewrite<ModifyOperationRewrite>(args&: op);
1816}
1817
1818void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
1819 assert(!impl->wasOpReplaced(op) &&
1820 "attempting to modify a replaced/erased op");
1821 PatternRewriter::finalizeOpModification(op);
1822 // There is nothing to do here, we only need to track the operation at the
1823 // start of the update.
1824#ifndef NDEBUG
1825 assert(impl->pendingRootUpdates.erase(op) &&
1826 "operation did not have a pending in-place update");
1827#endif
1828}
1829
1830void ConversionPatternRewriter::cancelOpModification(Operation *op) {
1831#ifndef NDEBUG
1832 assert(impl->pendingRootUpdates.erase(op) &&
1833 "operation did not have a pending in-place update");
1834#endif
1835 // Erase the last update for this operation.
1836 auto it = llvm::find_if(
1837 Range: llvm::reverse(C&: impl->rewrites), P: [&](std::unique_ptr<IRRewrite> &rewrite) {
1838 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(Val: rewrite.get());
1839 return modifyRewrite && modifyRewrite->getOperation() == op;
1840 });
1841 assert(it != impl->rewrites.rend() && "no root update started on op");
1842 (*it)->rollback();
1843 int updateIdx = std::prev(x: impl->rewrites.rend()) - it;
1844 impl->rewrites.erase(CI: impl->rewrites.begin() + updateIdx);
1845}
1846
1847detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
1848 return *impl;
1849}
1850
1851//===----------------------------------------------------------------------===//
1852// ConversionPattern
1853//===----------------------------------------------------------------------===//
1854
1855SmallVector<Value> ConversionPattern::getOneToOneAdaptorOperands(
1856 ArrayRef<ValueRange> operands) const {
1857 SmallVector<Value> oneToOneOperands;
1858 oneToOneOperands.reserve(N: operands.size());
1859 for (ValueRange operand : operands) {
1860 if (operand.size() != 1)
1861 llvm::report_fatal_error(reason: "pattern '" + getDebugName() +
1862 "' does not support 1:N conversion");
1863 oneToOneOperands.push_back(Elt: operand.front());
1864 }
1865 return oneToOneOperands;
1866}
1867
1868LogicalResult
1869ConversionPattern::matchAndRewrite(Operation *op,
1870 PatternRewriter &rewriter) const {
1871 auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1872 auto &rewriterImpl = dialectRewriter.getImpl();
1873
1874 // Track the current conversion pattern type converter in the rewriter.
1875 llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
1876 getTypeConverter());
1877
1878 // Remap the operands of the operation.
1879 SmallVector<ValueVector> remapped;
1880 if (failed(Result: rewriterImpl.remapValues(valueDiagTag: "operand", inputLoc: op->getLoc(), rewriter,
1881 values: op->getOperands(), remapped))) {
1882 return failure();
1883 }
1884 SmallVector<ValueRange> remappedAsRange =
1885 llvm::to_vector_of<ValueRange>(Range&: remapped);
1886 return matchAndRewrite(op, operands: remappedAsRange, rewriter&: dialectRewriter);
1887}
1888
1889//===----------------------------------------------------------------------===//
1890// OperationLegalizer
1891//===----------------------------------------------------------------------===//
1892
1893namespace {
1894/// A set of rewrite patterns that can be used to legalize a given operation.
1895using LegalizationPatterns = SmallVector<const Pattern *, 1>;
1896
1897/// This class defines a recursive operation legalizer.
1898class OperationLegalizer {
1899public:
1900 using LegalizationAction = ConversionTarget::LegalizationAction;
1901
1902 OperationLegalizer(const ConversionTarget &targetInfo,
1903 const FrozenRewritePatternSet &patterns,
1904 const ConversionConfig &config);
1905
1906 /// Returns true if the given operation is known to be illegal on the target.
1907 bool isIllegal(Operation *op) const;
1908
1909 /// Attempt to legalize the given operation. Returns success if the operation
1910 /// was legalized, failure otherwise.
1911 LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
1912
1913 /// Returns the conversion target in use by the legalizer.
1914 const ConversionTarget &getTarget() { return target; }
1915
1916private:
1917 /// Attempt to legalize the given operation by folding it.
1918 LogicalResult legalizeWithFold(Operation *op,
1919 ConversionPatternRewriter &rewriter);
1920
1921 /// Attempt to legalize the given operation by applying a pattern. Returns
1922 /// success if the operation was legalized, failure otherwise.
1923 LogicalResult legalizeWithPattern(Operation *op,
1924 ConversionPatternRewriter &rewriter);
1925
1926 /// Return true if the given pattern may be applied to the given operation,
1927 /// false otherwise.
1928 bool canApplyPattern(Operation *op, const Pattern &pattern,
1929 ConversionPatternRewriter &rewriter);
1930
1931 /// Legalize the resultant IR after successfully applying the given pattern.
1932 LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
1933 ConversionPatternRewriter &rewriter,
1934 RewriterState &curState);
1935
1936 /// Legalizes the actions registered during the execution of a pattern.
1937 LogicalResult
1938 legalizePatternBlockRewrites(Operation *op,
1939 ConversionPatternRewriter &rewriter,
1940 ConversionPatternRewriterImpl &impl,
1941 RewriterState &state, RewriterState &newState);
1942 LogicalResult legalizePatternCreatedOperations(
1943 ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
1944 RewriterState &state, RewriterState &newState);
1945 LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
1946 ConversionPatternRewriterImpl &impl,
1947 RewriterState &state,
1948 RewriterState &newState);
1949
1950 //===--------------------------------------------------------------------===//
1951 // Cost Model
1952 //===--------------------------------------------------------------------===//
1953
1954 /// Build an optimistic legalization graph given the provided patterns. This
1955 /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
1956 /// patterns for operations that are not directly legal, but may be
1957 /// transitively legal for the current target given the provided patterns.
1958 void buildLegalizationGraph(
1959 LegalizationPatterns &anyOpLegalizerPatterns,
1960 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
1961
1962 /// Compute the benefit of each node within the computed legalization graph.
1963 /// This orders the patterns within 'legalizerPatterns' based upon two
1964 /// criteria:
1965 /// 1) Prefer patterns that have the lowest legalization depth, i.e.
1966 /// represent the more direct mapping to the target.
1967 /// 2) When comparing patterns with the same legalization depth, prefer the
1968 /// pattern with the highest PatternBenefit. This allows for users to
1969 /// prefer specific legalizations over others.
1970 void computeLegalizationGraphBenefit(
1971 LegalizationPatterns &anyOpLegalizerPatterns,
1972 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
1973
1974 /// Compute the legalization depth when legalizing an operation of the given
1975 /// type.
1976 unsigned computeOpLegalizationDepth(
1977 OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
1978 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
1979
1980 /// Apply the conversion cost model to the given set of patterns, and return
1981 /// the smallest legalization depth of any of the patterns. See
1982 /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
1983 unsigned applyCostModelToPatterns(
1984 LegalizationPatterns &patterns,
1985 DenseMap<OperationName, unsigned> &minOpPatternDepth,
1986 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
1987
1988 /// The current set of patterns that have been applied.
1989 SmallPtrSet<const Pattern *, 8> appliedPatterns;
1990
1991 /// The legalization information provided by the target.
1992 const ConversionTarget &target;
1993
1994 /// The pattern applicator to use for conversions.
1995 PatternApplicator applicator;
1996
1997 /// Dialect conversion configuration.
1998 const ConversionConfig &config;
1999};
2000} // namespace
2001
2002OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
2003 const FrozenRewritePatternSet &patterns,
2004 const ConversionConfig &config)
2005 : target(targetInfo), applicator(patterns), config(config) {
2006 // The set of patterns that can be applied to illegal operations to transform
2007 // them into legal ones.
2008 DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
2009 LegalizationPatterns anyOpLegalizerPatterns;
2010
2011 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2012 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2013}
2014
2015bool OperationLegalizer::isIllegal(Operation *op) const {
2016 return target.isIllegal(op);
2017}
2018
2019LogicalResult
2020OperationLegalizer::legalize(Operation *op,
2021 ConversionPatternRewriter &rewriter) {
2022#ifndef NDEBUG
2023 const char *logLineComment =
2024 "//===-------------------------------------------===//\n";
2025
2026 auto &logger = rewriter.getImpl().logger;
2027#endif
2028 LLVM_DEBUG({
2029 logger.getOStream() << "\n";
2030 logger.startLine() << logLineComment;
2031 logger.startLine() << "Legalizing operation : '" << op->getName() << "'("
2032 << op << ") {\n";
2033 logger.indent();
2034
2035 // If the operation has no regions, just print it here.
2036 if (op->getNumRegions() == 0) {
2037 op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
2038 logger.getOStream() << "\n\n";
2039 }
2040 });
2041
2042 // Check if this operation is legal on the target.
2043 if (auto legalityInfo = target.isLegal(op)) {
2044 LLVM_DEBUG({
2045 logSuccess(
2046 logger, "operation marked legal by the target{0}",
2047 legalityInfo->isRecursivelyLegal
2048 ? "; NOTE: operation is recursively legal; skipping internals"
2049 : "");
2050 logger.startLine() << logLineComment;
2051 });
2052
2053 // If this operation is recursively legal, mark its children as ignored so
2054 // that we don't consider them for legalization.
2055 if (legalityInfo->isRecursivelyLegal) {
2056 op->walk(callback: [&](Operation *nested) {
2057 if (op != nested)
2058 rewriter.getImpl().ignoredOps.insert(X: nested);
2059 });
2060 }
2061
2062 return success();
2063 }
2064
2065 // Check to see if the operation is ignored and doesn't need to be converted.
2066 if (rewriter.getImpl().isOpIgnored(op)) {
2067 LLVM_DEBUG({
2068 logSuccess(logger, "operation marked 'ignored' during conversion");
2069 logger.startLine() << logLineComment;
2070 });
2071 return success();
2072 }
2073
2074 // If the operation isn't legal, try to fold it in-place.
2075 // TODO: Should we always try to do this, even if the op is
2076 // already legal?
2077 if (succeeded(Result: legalizeWithFold(op, rewriter))) {
2078 LLVM_DEBUG({
2079 logSuccess(logger, "operation was folded");
2080 logger.startLine() << logLineComment;
2081 });
2082 return success();
2083 }
2084
2085 // Otherwise, we need to apply a legalization pattern to this operation.
2086 if (succeeded(Result: legalizeWithPattern(op, rewriter))) {
2087 LLVM_DEBUG({
2088 logSuccess(logger, "");
2089 logger.startLine() << logLineComment;
2090 });
2091 return success();
2092 }
2093
2094 LLVM_DEBUG({
2095 logFailure(logger, "no matched legalization pattern");
2096 logger.startLine() << logLineComment;
2097 });
2098 return failure();
2099}
2100
2101LogicalResult
2102OperationLegalizer::legalizeWithFold(Operation *op,
2103 ConversionPatternRewriter &rewriter) {
2104 auto &rewriterImpl = rewriter.getImpl();
2105 LLVM_DEBUG({
2106 rewriterImpl.logger.startLine() << "* Fold {\n";
2107 rewriterImpl.logger.indent();
2108 });
2109 (void)rewriterImpl;
2110
2111 // Try to fold the operation.
2112 SmallVector<Value, 2> replacementValues;
2113 SmallVector<Operation *, 2> newOps;
2114 rewriter.setInsertionPoint(op);
2115 if (failed(Result: rewriter.tryFold(op, results&: replacementValues, materializedConstants: &newOps))) {
2116 LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
2117 return failure();
2118 }
2119
2120 // An empty list of replacement values indicates that the fold was in-place.
2121 // As the operation changed, a new legalization needs to be attempted.
2122 if (replacementValues.empty())
2123 return legalize(op, rewriter);
2124
2125 // Recursively legalize any new constant operations.
2126 for (Operation *newOp : newOps) {
2127 if (failed(Result: legalize(op: newOp, rewriter))) {
2128 LLVM_DEBUG(logFailure(rewriterImpl.logger,
2129 "failed to legalize generated constant '{0}'",
2130 newOp->getName()));
2131 // Legalization failed: erase all materialized constants.
2132 for (Operation *op : newOps)
2133 rewriter.eraseOp(op);
2134 return failure();
2135 }
2136 }
2137
2138 // Insert a replacement for 'op' with the folded replacement values.
2139 rewriter.replaceOp(op, newValues: replacementValues);
2140
2141 LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
2142 return success();
2143}
2144
2145LogicalResult
2146OperationLegalizer::legalizeWithPattern(Operation *op,
2147 ConversionPatternRewriter &rewriter) {
2148 auto &rewriterImpl = rewriter.getImpl();
2149
2150 // Functor that returns if the given pattern may be applied.
2151 auto canApply = [&](const Pattern &pattern) {
2152 bool canApply = canApplyPattern(op, pattern, rewriter);
2153 if (canApply && config.listener)
2154 config.listener->notifyPatternBegin(pattern, op);
2155 return canApply;
2156 };
2157
2158 // Functor that cleans up the rewriter state after a pattern failed to match.
2159 RewriterState curState = rewriterImpl.getCurrentState();
2160 auto onFailure = [&](const Pattern &pattern) {
2161 assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2162 LLVM_DEBUG({
2163 logFailure(rewriterImpl.logger, "pattern failed to match");
2164 if (rewriterImpl.config.notifyCallback) {
2165 Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
2166 diag << "Failed to apply pattern \"" << pattern.getDebugName()
2167 << "\" on op:\n"
2168 << *op;
2169 rewriterImpl.config.notifyCallback(diag);
2170 }
2171 });
2172 if (config.listener)
2173 config.listener->notifyPatternEnd(pattern, status: failure());
2174 rewriterImpl.resetState(state: curState, patternName: pattern.getDebugName());
2175 appliedPatterns.erase(Ptr: &pattern);
2176 };
2177
2178 // Functor that performs additional legalization when a pattern is
2179 // successfully applied.
2180 auto onSuccess = [&](const Pattern &pattern) {
2181 assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2182 auto result = legalizePatternResult(op, pattern, rewriter, curState);
2183 appliedPatterns.erase(Ptr: &pattern);
2184 if (failed(Result: result)) {
2185 if (!rewriterImpl.config.allowPatternRollback)
2186 op->emitError(message: "pattern '")
2187 << pattern.getDebugName()
2188 << "' produced IR that could not be legalized";
2189 rewriterImpl.resetState(state: curState, patternName: pattern.getDebugName());
2190 }
2191 if (config.listener)
2192 config.listener->notifyPatternEnd(pattern, status: result);
2193 return result;
2194 };
2195
2196 // Try to match and rewrite a pattern on this operation.
2197 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2198 onSuccess);
2199}
2200
2201bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
2202 ConversionPatternRewriter &rewriter) {
2203 LLVM_DEBUG({
2204 auto &os = rewriter.getImpl().logger;
2205 os.getOStream() << "\n";
2206 os.startLine() << "* Pattern : '" << op->getName() << " -> (";
2207 llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
2208 os.getOStream() << ")' {\n";
2209 os.indent();
2210 });
2211
2212 // Ensure that we don't cycle by not allowing the same pattern to be
2213 // applied twice in the same recursion stack if it is not known to be safe.
2214 if (!pattern.hasBoundedRewriteRecursion() &&
2215 !appliedPatterns.insert(Ptr: &pattern).second) {
2216 LLVM_DEBUG(
2217 logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2218 return false;
2219 }
2220 return true;
2221}
2222
2223LogicalResult
2224OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
2225 ConversionPatternRewriter &rewriter,
2226 RewriterState &curState) {
2227 auto &impl = rewriter.getImpl();
2228 assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2229
2230#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2231 // Check that the root was either replaced or updated in place.
2232 auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2233 auto replacedRoot = [&] {
2234 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2235 };
2236 auto updatedRootInPlace = [&] {
2237 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2238 };
2239 if (!replacedRoot() && !updatedRootInPlace())
2240 llvm::report_fatal_error("expected pattern to replace the root operation");
2241#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2242
2243 // Legalize each of the actions registered during application.
2244 RewriterState newState = impl.getCurrentState();
2245 if (failed(Result: legalizePatternBlockRewrites(op, rewriter, impl, state&: curState,
2246 newState)) ||
2247 failed(Result: legalizePatternRootUpdates(rewriter, impl, state&: curState, newState)) ||
2248 failed(Result: legalizePatternCreatedOperations(rewriter, impl, state&: curState,
2249 newState))) {
2250 return failure();
2251 }
2252
2253 LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2254 return success();
2255}
2256
2257LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2258 Operation *op, ConversionPatternRewriter &rewriter,
2259 ConversionPatternRewriterImpl &impl, RewriterState &state,
2260 RewriterState &newState) {
2261 SmallPtrSet<Operation *, 16> operationsToIgnore;
2262
2263 // If the pattern moved or created any blocks, make sure the types of block
2264 // arguments get legalized.
2265 for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2266 BlockRewrite *rewrite = dyn_cast<BlockRewrite>(Val: impl.rewrites[i].get());
2267 if (!rewrite)
2268 continue;
2269 Block *block = rewrite->getBlock();
2270 if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
2271 ReplaceBlockArgRewrite>(Val: rewrite))
2272 continue;
2273 // Only check blocks outside of the current operation.
2274 Operation *parentOp = block->getParentOp();
2275 if (!parentOp || parentOp == op || block->getNumArguments() == 0)
2276 continue;
2277
2278 // If the region of the block has a type converter, try to convert the block
2279 // directly.
2280 if (auto *converter = impl.regionToConverter.lookup(Val: block->getParent())) {
2281 std::optional<TypeConverter::SignatureConversion> conversion =
2282 converter->convertBlockSignature(block);
2283 if (!conversion) {
2284 LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
2285 "block"));
2286 return failure();
2287 }
2288 impl.applySignatureConversion(rewriter, block, converter, signatureConversion&: *conversion);
2289 continue;
2290 }
2291
2292 // Otherwise, check that this operation isn't one generated by this pattern.
2293 // This is because we will attempt to legalize the parent operation, and
2294 // blocks in regions created by this pattern will already be legalized later
2295 // on. If we haven't built the set yet, build it now.
2296 if (operationsToIgnore.empty()) {
2297 for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
2298 ++i) {
2299 auto *createOp =
2300 dyn_cast<CreateOperationRewrite>(Val: impl.rewrites[i].get());
2301 if (!createOp)
2302 continue;
2303 operationsToIgnore.insert(Ptr: createOp->getOperation());
2304 }
2305 }
2306
2307 // If this operation should be considered for re-legalization, try it.
2308 if (operationsToIgnore.insert(Ptr: parentOp).second &&
2309 failed(Result: legalize(op: parentOp, rewriter))) {
2310 LLVM_DEBUG(logFailure(impl.logger,
2311 "operation '{0}'({1}) became illegal after rewrite",
2312 parentOp->getName(), parentOp));
2313 return failure();
2314 }
2315 }
2316 return success();
2317}
2318
2319LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2320 ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
2321 RewriterState &state, RewriterState &newState) {
2322 for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2323 auto *createOp = dyn_cast<CreateOperationRewrite>(Val: impl.rewrites[i].get());
2324 if (!createOp)
2325 continue;
2326 Operation *op = createOp->getOperation();
2327 if (failed(Result: legalize(op, rewriter))) {
2328 LLVM_DEBUG(logFailure(impl.logger,
2329 "failed to legalize generated operation '{0}'({1})",
2330 op->getName(), op));
2331 return failure();
2332 }
2333 }
2334 return success();
2335}
2336
2337LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2338 ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
2339 RewriterState &state, RewriterState &newState) {
2340 for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2341 auto *rewrite = dyn_cast<ModifyOperationRewrite>(Val: impl.rewrites[i].get());
2342 if (!rewrite)
2343 continue;
2344 Operation *op = rewrite->getOperation();
2345 if (failed(Result: legalize(op, rewriter))) {
2346 LLVM_DEBUG(logFailure(
2347 impl.logger, "failed to legalize operation updated in-place '{0}'",
2348 op->getName()));
2349 return failure();
2350 }
2351 }
2352 return success();
2353}
2354
2355//===----------------------------------------------------------------------===//
2356// Cost Model
2357//===----------------------------------------------------------------------===//
2358
2359void OperationLegalizer::buildLegalizationGraph(
2360 LegalizationPatterns &anyOpLegalizerPatterns,
2361 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2362 // A mapping between an operation and a set of operations that can be used to
2363 // generate it.
2364 DenseMap<OperationName, SmallPtrSet<OperationName, 2>> parentOps;
2365 // A mapping between an operation and any currently invalid patterns it has.
2366 DenseMap<OperationName, SmallPtrSet<const Pattern *, 2>> invalidPatterns;
2367 // A worklist of patterns to consider for legality.
2368 SetVector<const Pattern *> patternWorklist;
2369
2370 // Build the mapping from operations to the parent ops that may generate them.
2371 applicator.walkAllPatterns(walk: [&](const Pattern &pattern) {
2372 std::optional<OperationName> root = pattern.getRootKind();
2373
2374 // If the pattern has no specific root, we can't analyze the relationship
2375 // between the root op and generated operations. Given that, add all such
2376 // patterns to the legalization set.
2377 if (!root) {
2378 anyOpLegalizerPatterns.push_back(Elt: &pattern);
2379 return;
2380 }
2381
2382 // Skip operations that are always known to be legal.
2383 if (target.getOpAction(op: *root) == LegalizationAction::Legal)
2384 return;
2385
2386 // Add this pattern to the invalid set for the root op and record this root
2387 // as a parent for any generated operations.
2388 invalidPatterns[*root].insert(Ptr: &pattern);
2389 for (auto op : pattern.getGeneratedOps())
2390 parentOps[op].insert(Ptr: *root);
2391
2392 // Add this pattern to the worklist.
2393 patternWorklist.insert(X: &pattern);
2394 });
2395
2396 // If there are any patterns that don't have a specific root kind, we can't
2397 // make direct assumptions about what operations will never be legalized.
2398 // Note: Technically we could, but it would require an analysis that may
2399 // recurse into itself. It would be better to perform this kind of filtering
2400 // at a higher level than here anyways.
2401 if (!anyOpLegalizerPatterns.empty()) {
2402 for (const Pattern *pattern : patternWorklist)
2403 legalizerPatterns[*pattern->getRootKind()].push_back(Elt: pattern);
2404 return;
2405 }
2406
2407 while (!patternWorklist.empty()) {
2408 auto *pattern = patternWorklist.pop_back_val();
2409
2410 // Check to see if any of the generated operations are invalid.
2411 if (llvm::any_of(Range: pattern->getGeneratedOps(), P: [&](OperationName op) {
2412 std::optional<LegalizationAction> action = target.getOpAction(op);
2413 return !legalizerPatterns.count(Val: op) &&
2414 (!action || action == LegalizationAction::Illegal);
2415 }))
2416 continue;
2417
2418 // Otherwise, if all of the generated operation are valid, this op is now
2419 // legal so add all of the child patterns to the worklist.
2420 legalizerPatterns[*pattern->getRootKind()].push_back(Elt: pattern);
2421 invalidPatterns[*pattern->getRootKind()].erase(Ptr: pattern);
2422
2423 // Add any invalid patterns of the parent operations to see if they have now
2424 // become legal.
2425 for (auto op : parentOps[*pattern->getRootKind()])
2426 patternWorklist.set_union(invalidPatterns[op]);
2427 }
2428}
2429
2430void OperationLegalizer::computeLegalizationGraphBenefit(
2431 LegalizationPatterns &anyOpLegalizerPatterns,
2432 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2433 // The smallest pattern depth, when legalizing an operation.
2434 DenseMap<OperationName, unsigned> minOpPatternDepth;
2435
2436 // For each operation that is transitively legal, compute a cost for it.
2437 for (auto &opIt : legalizerPatterns)
2438 if (!minOpPatternDepth.count(Val: opIt.first))
2439 computeOpLegalizationDepth(op: opIt.first, minOpPatternDepth,
2440 legalizerPatterns);
2441
2442 // Apply the cost model to the patterns that can match any operation. Those
2443 // with a specific operation type are already resolved when computing the op
2444 // legalization depth.
2445 if (!anyOpLegalizerPatterns.empty())
2446 applyCostModelToPatterns(patterns&: anyOpLegalizerPatterns, minOpPatternDepth,
2447 legalizerPatterns);
2448
2449 // Apply a cost model to the pattern applicator. We order patterns first by
2450 // depth then benefit. `legalizerPatterns` contains per-op patterns by
2451 // decreasing benefit.
2452 applicator.applyCostModel(model: [&](const Pattern &pattern) {
2453 ArrayRef<const Pattern *> orderedPatternList;
2454 if (std::optional<OperationName> rootName = pattern.getRootKind())
2455 orderedPatternList = legalizerPatterns[*rootName];
2456 else
2457 orderedPatternList = anyOpLegalizerPatterns;
2458
2459 // If the pattern is not found, then it was removed and cannot be matched.
2460 auto *it = llvm::find(Range&: orderedPatternList, Val: &pattern);
2461 if (it == orderedPatternList.end())
2462 return PatternBenefit::impossibleToMatch();
2463
2464 // Patterns found earlier in the list have higher benefit.
2465 return PatternBenefit(std::distance(first: it, last: orderedPatternList.end()));
2466 });
2467}
2468
2469unsigned OperationLegalizer::computeOpLegalizationDepth(
2470 OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2471 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2472 // Check for existing depth.
2473 auto depthIt = minOpPatternDepth.find(Val: op);
2474 if (depthIt != minOpPatternDepth.end())
2475 return depthIt->second;
2476
2477 // If a mapping for this operation does not exist, then this operation
2478 // is always legal. Return 0 as the depth for a directly legal operation.
2479 auto opPatternsIt = legalizerPatterns.find(Val: op);
2480 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2481 return 0u;
2482
2483 // Record this initial depth in case we encounter this op again when
2484 // recursively computing the depth.
2485 minOpPatternDepth.try_emplace(Key: op, Args: std::numeric_limits<unsigned>::max());
2486
2487 // Apply the cost model to the operation patterns, and update the minimum
2488 // depth.
2489 unsigned minDepth = applyCostModelToPatterns(
2490 patterns&: opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2491 minOpPatternDepth[op] = minDepth;
2492 return minDepth;
2493}
2494
2495unsigned OperationLegalizer::applyCostModelToPatterns(
2496 LegalizationPatterns &patterns,
2497 DenseMap<OperationName, unsigned> &minOpPatternDepth,
2498 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2499 unsigned minDepth = std::numeric_limits<unsigned>::max();
2500
2501 // Compute the depth for each pattern within the set.
2502 SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
2503 patternsByDepth.reserve(N: patterns.size());
2504 for (const Pattern *pattern : patterns) {
2505 unsigned depth = 1;
2506 for (auto generatedOp : pattern->getGeneratedOps()) {
2507 unsigned generatedOpDepth = computeOpLegalizationDepth(
2508 op: generatedOp, minOpPatternDepth, legalizerPatterns);
2509 depth = std::max(a: depth, b: generatedOpDepth + 1);
2510 }
2511 patternsByDepth.emplace_back(Args&: pattern, Args&: depth);
2512
2513 // Update the minimum depth of the pattern list.
2514 minDepth = std::min(a: minDepth, b: depth);
2515 }
2516
2517 // If the operation only has one legalization pattern, there is no need to
2518 // sort them.
2519 if (patternsByDepth.size() == 1)
2520 return minDepth;
2521
2522 // Sort the patterns by those likely to be the most beneficial.
2523 llvm::stable_sort(Range&: patternsByDepth,
2524 C: [](const std::pair<const Pattern *, unsigned> &lhs,
2525 const std::pair<const Pattern *, unsigned> &rhs) {
2526 // First sort by the smaller pattern legalization
2527 // depth.
2528 if (lhs.second != rhs.second)
2529 return lhs.second < rhs.second;
2530
2531 // Then sort by the larger pattern benefit.
2532 auto lhsBenefit = lhs.first->getBenefit();
2533 auto rhsBenefit = rhs.first->getBenefit();
2534 return lhsBenefit > rhsBenefit;
2535 });
2536
2537 // Update the legalization pattern to use the new sorted list.
2538 patterns.clear();
2539 for (auto &patternIt : patternsByDepth)
2540 patterns.push_back(Elt: patternIt.first);
2541 return minDepth;
2542}
2543
2544//===----------------------------------------------------------------------===//
2545// OperationConverter
2546//===----------------------------------------------------------------------===//
namespace {
/// The mode in which an `OperationConverter` runs. It determines how a
/// failure to legalize an operation is treated.
enum OpConversionMode {
  /// In this mode, the conversion will ignore failed conversions to allow
  /// illegal operations to co-exist in the IR. Operations that were explicitly
  /// marked as illegal must still be legalized.
  Partial,

  /// In this mode, all operations must be legal for the given target for the
  /// conversion to succeed.
  Full,

  /// In this mode, operations are analyzed for legality. No actual rewrites are
  /// applied to the operations on success.
  Analysis,
};
} // namespace
2562
namespace mlir {
/// This class converts operations to a given conversion target via a set of
/// rewrite patterns. The conversion behaves differently depending on the
/// conversion mode.
struct OperationConverter {
  /// Creates a converter for `target` that uses `patterns`. The configuration
  /// is copied into this object, and the legalizer is set up with that copy.
  explicit OperationConverter(const ConversionTarget &target,
                              const FrozenRewritePatternSet &patterns,
                              const ConversionConfig &config,
                              OpConversionMode mode)
      : config(config), opLegalizer(target, patterns, this->config),
        mode(mode) {}

  /// Converts the given operations to the conversion target.
  LogicalResult convertOperations(ArrayRef<Operation *> ops);

private:
  /// Converts an operation with the given rewriter.
  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);

  /// Dialect conversion configuration. Declared before `opLegalizer` because
  /// the legalizer is constructed from this member.
  ConversionConfig config;

  /// The legalizer to use when converting operations.
  OperationLegalizer opLegalizer;

  /// The conversion mode to use when legalizing operations.
  OpConversionMode mode;
};
} // namespace mlir
2592
2593LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2594 Operation *op) {
2595 // Legalize the given operation.
2596 if (failed(Result: opLegalizer.legalize(op, rewriter))) {
2597 // Handle the case of a failed conversion for each of the different modes.
2598 // Full conversions expect all operations to be converted.
2599 if (mode == OpConversionMode::Full)
2600 return op->emitError()
2601 << "failed to legalize operation '" << op->getName() << "'";
2602 // Partial conversions allow conversions to fail iff the operation was not
2603 // explicitly marked as illegal. If the user provided a `unlegalizedOps`
2604 // set, non-legalizable ops are added to that set.
2605 if (mode == OpConversionMode::Partial) {
2606 if (opLegalizer.isIllegal(op))
2607 return op->emitError()
2608 << "failed to legalize operation '" << op->getName()
2609 << "' that was explicitly marked illegal";
2610 if (config.unlegalizedOps)
2611 config.unlegalizedOps->insert(V: op);
2612 }
2613 } else if (mode == OpConversionMode::Analysis) {
2614 // Analysis conversions don't fail if any operations fail to legalize,
2615 // they are only interested in the operations that were successfully
2616 // legalized.
2617 if (config.legalizableOps)
2618 config.legalizableOps->insert(V: op);
2619 }
2620 return success();
2621}
2622
2623static LogicalResult
2624legalizeUnresolvedMaterialization(RewriterBase &rewriter,
2625 UnresolvedMaterializationRewrite *rewrite) {
2626 UnrealizedConversionCastOp op = rewrite->getOperation();
2627 assert(!op.use_empty() &&
2628 "expected that dead materializations have already been DCE'd");
2629 Operation::operand_range inputOperands = op.getOperands();
2630
2631 // Try to materialize the conversion.
2632 if (const TypeConverter *converter = rewrite->getConverter()) {
2633 rewriter.setInsertionPoint(op);
2634 SmallVector<Value> newMaterialization;
2635 switch (rewrite->getMaterializationKind()) {
2636 case MaterializationKind::Target:
2637 newMaterialization = converter->materializeTargetConversion(
2638 rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
2639 rewrite->getOriginalType());
2640 break;
2641 case MaterializationKind::Source:
2642 assert(op->getNumResults() == 1 && "expected single result");
2643 Value sourceMat = converter->materializeSourceConversion(
2644 builder&: rewriter, loc: op->getLoc(), resultType: op.getResultTypes().front(), inputs: inputOperands);
2645 if (sourceMat)
2646 newMaterialization.push_back(Elt: sourceMat);
2647 break;
2648 }
2649 if (!newMaterialization.empty()) {
2650#ifndef NDEBUG
2651 ValueRange newMaterializationRange(newMaterialization);
2652 assert(TypeRange(newMaterializationRange) == op.getResultTypes() &&
2653 "materialization callback produced value of incorrect type");
2654#endif // NDEBUG
2655 rewriter.replaceOp(op, newMaterialization);
2656 return success();
2657 }
2658 }
2659
2660 InFlightDiagnostic diag = op->emitError()
2661 << "failed to legalize unresolved materialization "
2662 "from ("
2663 << inputOperands.getTypes() << ") to ("
2664 << op.getResultTypes()
2665 << ") that remained live after conversion";
2666 diag.attachNote(noteLoc: op->getUsers().begin()->getLoc())
2667 << "see existing live user here: " << *op->getUsers().begin();
2668 return failure();
2669}
2670
2671LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2672 if (ops.empty())
2673 return success();
2674 const ConversionTarget &target = opLegalizer.getTarget();
2675
2676 // Compute the set of operations and blocks to convert.
2677 SmallVector<Operation *> toConvert;
2678 for (auto *op : ops) {
2679 op->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
2680 callback: [&](Operation *op) {
2681 toConvert.push_back(Elt: op);
2682 // Don't check this operation's children for conversion if the
2683 // operation is recursively legal.
2684 auto legalityInfo = target.isLegal(op);
2685 if (legalityInfo && legalityInfo->isRecursivelyLegal)
2686 return WalkResult::skip();
2687 return WalkResult::advance();
2688 });
2689 }
2690
2691 // Convert each operation and discard rewrites on failure.
2692 ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
2693 ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2694
2695 for (auto *op : toConvert) {
2696 if (failed(Result: convert(rewriter, op))) {
2697 // Dialect conversion failed.
2698 if (rewriterImpl.config.allowPatternRollback) {
2699 // Rollback is allowed: restore the original IR.
2700 rewriterImpl.undoRewrites();
2701 } else {
2702 // Rollback is not allowed: apply all modifications that have been
2703 // performed so far.
2704 rewriterImpl.applyRewrites();
2705 }
2706 return failure();
2707 }
2708 }
2709
2710 // After a successful conversion, apply rewrites.
2711 rewriterImpl.applyRewrites();
2712
2713 // Gather all unresolved materializations.
2714 SmallVector<UnrealizedConversionCastOp> allCastOps;
2715 const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
2716 &materializations = rewriterImpl.unresolvedMaterializations;
2717 for (auto it : materializations) {
2718 if (rewriterImpl.eraseRewriter.wasErased(it.first))
2719 continue;
2720 allCastOps.push_back(it.first);
2721 }
2722
2723 // Reconcile all UnrealizedConversionCastOps that were inserted by the
2724 // dialect conversion frameworks. (Not the one that were inserted by
2725 // patterns.)
2726 SmallVector<UnrealizedConversionCastOp> remainingCastOps;
2727 reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
2728
2729 // Try to legalize all unresolved materializations.
2730 if (config.buildMaterializations) {
2731 IRRewriter rewriter(rewriterImpl.context, config.listener);
2732 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
2733 auto it = materializations.find(castOp);
2734 assert(it != materializations.end() && "inconsistent state");
2735 if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
2736 return failure();
2737 }
2738 }
2739
2740 return success();
2741}
2742
2743//===----------------------------------------------------------------------===//
2744// Reconcile Unrealized Casts
2745//===----------------------------------------------------------------------===//
2746
2747void mlir::reconcileUnrealizedCasts(
2748 ArrayRef<UnrealizedConversionCastOp> castOps,
2749 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
2750 SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);
2751 // This set is maintained only if `remainingCastOps` is provided.
2752 DenseSet<Operation *> erasedOps;
2753
2754 // Helper function that adds all operands to the worklist that are an
2755 // unrealized_conversion_cast op result.
2756 auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
2757 for (Value v : castOp.getInputs())
2758 if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
2759 worklist.insert(inputCastOp);
2760 };
2761
2762 // Helper function that return the unrealized_conversion_cast op that
2763 // defines all inputs of the given op (in the same order). Return "nullptr"
2764 // if there is no such op.
2765 auto getInputCast =
2766 [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
2767 if (castOp.getInputs().empty())
2768 return {};
2769 auto inputCastOp =
2770 castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
2771 if (!inputCastOp)
2772 return {};
2773 if (inputCastOp.getOutputs() != castOp.getInputs())
2774 return {};
2775 return inputCastOp;
2776 };
2777
2778 // Process ops in the worklist bottom-to-top.
2779 while (!worklist.empty()) {
2780 UnrealizedConversionCastOp castOp = worklist.pop_back_val();
2781 if (castOp->use_empty()) {
2782 // DCE: If the op has no users, erase it. Add the operands to the
2783 // worklist to find additional DCE opportunities.
2784 enqueueOperands(castOp);
2785 if (remainingCastOps)
2786 erasedOps.insert(castOp.getOperation());
2787 castOp->erase();
2788 continue;
2789 }
2790
2791 // Traverse the chain of input cast ops to see if an op with the same
2792 // input types can be found.
2793 UnrealizedConversionCastOp nextCast = castOp;
2794 while (nextCast) {
2795 if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
2796 // Found a cast where the input types match the output types of the
2797 // matched op. We can directly use those inputs and the matched op can
2798 // be removed.
2799 enqueueOperands(castOp);
2800 castOp.replaceAllUsesWith(nextCast.getInputs());
2801 if (remainingCastOps)
2802 erasedOps.insert(castOp.getOperation());
2803 castOp->erase();
2804 break;
2805 }
2806 nextCast = getInputCast(nextCast);
2807 }
2808 }
2809
2810 if (remainingCastOps)
2811 for (UnrealizedConversionCastOp op : castOps)
2812 if (!erasedOps.contains(op.getOperation()))
2813 remainingCastOps->push_back(op);
2814}
2815
2816//===----------------------------------------------------------------------===//
2817// Type Conversion
2818//===----------------------------------------------------------------------===//
2819
2820void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
2821 ArrayRef<Type> types) {
2822 assert(!types.empty() && "expected valid types");
2823 remapInput(origInputNo, /*newInputNo=*/argTypes.size(), newInputCount: types.size());
2824 addInputs(types);
2825}
2826
2827void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) {
2828 assert(!types.empty() &&
2829 "1->0 type remappings don't need to be added explicitly");
2830 argTypes.append(in_start: types.begin(), in_end: types.end());
2831}
2832
2833void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2834 unsigned newInputNo,
2835 unsigned newInputCount) {
2836 assert(!remappedInputs[origInputNo] && "input has already been remapped");
2837 assert(newInputCount != 0 && "expected valid input count");
2838 remappedInputs[origInputNo] =
2839 InputMapping{.inputNo: newInputNo, .size: newInputCount, /*replacementValues=*/{}};
2840}
2841
2842void TypeConverter::SignatureConversion::remapInput(
2843 unsigned origInputNo, ArrayRef<Value> replacements) {
2844 assert(!remappedInputs[origInputNo] && "input has already been remapped");
2845 remappedInputs[origInputNo] = InputMapping{
2846 .inputNo: origInputNo, /*size=*/0,
2847 .replacementValues: SmallVector<Value, 1>(replacements.begin(), replacements.end())};
2848}
2849
2850LogicalResult TypeConverter::convertType(Type t,
2851 SmallVectorImpl<Type> &results) const {
2852 assert(t && "expected non-null type");
2853
2854 {
2855 std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
2856 std::defer_lock);
2857 if (t.getContext()->isMultithreadingEnabled())
2858 cacheReadLock.lock();
2859 auto existingIt = cachedDirectConversions.find(Val: t);
2860 if (existingIt != cachedDirectConversions.end()) {
2861 if (existingIt->second)
2862 results.push_back(Elt: existingIt->second);
2863 return success(IsSuccess: existingIt->second != nullptr);
2864 }
2865 auto multiIt = cachedMultiConversions.find(Val: t);
2866 if (multiIt != cachedMultiConversions.end()) {
2867 results.append(in_start: multiIt->second.begin(), in_end: multiIt->second.end());
2868 return success();
2869 }
2870 }
2871 // Walk the added converters in reverse order to apply the most recently
2872 // registered first.
2873 size_t currentCount = results.size();
2874
2875 std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
2876 std::defer_lock);
2877
2878 for (const ConversionCallbackFn &converter : llvm::reverse(C: conversions)) {
2879 if (std::optional<LogicalResult> result = converter(t, results)) {
2880 if (t.getContext()->isMultithreadingEnabled())
2881 cacheWriteLock.lock();
2882 if (!succeeded(Result: *result)) {
2883 assert(results.size() == currentCount &&
2884 "failed type conversion should not change results");
2885 cachedDirectConversions.try_emplace(Key: t, Args: nullptr);
2886 return failure();
2887 }
2888 auto newTypes = ArrayRef<Type>(results).drop_front(N: currentCount);
2889 if (newTypes.size() == 1)
2890 cachedDirectConversions.try_emplace(Key: t, Args: newTypes.front());
2891 else
2892 cachedMultiConversions.try_emplace(Key: t, Args: llvm::to_vector<2>(Range&: newTypes));
2893 return success();
2894 } else {
2895 assert(results.size() == currentCount &&
2896 "failed type conversion should not change results");
2897 }
2898 }
2899 return failure();
2900}
2901
2902Type TypeConverter::convertType(Type t) const {
2903 // Use the multi-type result version to convert the type.
2904 SmallVector<Type, 1> results;
2905 if (failed(Result: convertType(t, results)))
2906 return nullptr;
2907
2908 // Check to ensure that only one type was produced.
2909 return results.size() == 1 ? results.front() : nullptr;
2910}
2911
2912LogicalResult
2913TypeConverter::convertTypes(TypeRange types,
2914 SmallVectorImpl<Type> &results) const {
2915 for (Type type : types)
2916 if (failed(Result: convertType(t: type, results)))
2917 return failure();
2918 return success();
2919}
2920
2921bool TypeConverter::isLegal(Type type) const {
2922 return convertType(t: type) == type;
2923}
2924bool TypeConverter::isLegal(Operation *op) const {
2925 return isLegal(range: op->getOperandTypes()) && isLegal(range: op->getResultTypes());
2926}
2927
2928bool TypeConverter::isLegal(Region *region) const {
2929 return llvm::all_of(Range&: *region, P: [this](Block &block) {
2930 return isLegal(range: block.getArgumentTypes());
2931 });
2932}
2933
2934bool TypeConverter::isSignatureLegal(FunctionType ty) const {
2935 return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
2936}
2937
2938LogicalResult
2939TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
2940 SignatureConversion &result) const {
2941 // Try to convert the given input type.
2942 SmallVector<Type, 1> convertedTypes;
2943 if (failed(Result: convertType(t: type, results&: convertedTypes)))
2944 return failure();
2945
2946 // If this argument is being dropped, there is nothing left to do.
2947 if (convertedTypes.empty())
2948 return success();
2949
2950 // Otherwise, add the new inputs.
2951 result.addInputs(origInputNo: inputNo, types: convertedTypes);
2952 return success();
2953}
2954LogicalResult
2955TypeConverter::convertSignatureArgs(TypeRange types,
2956 SignatureConversion &result,
2957 unsigned origInputOffset) const {
2958 for (unsigned i = 0, e = types.size(); i != e; ++i)
2959 if (failed(Result: convertSignatureArg(inputNo: origInputOffset + i, type: types[i], result)))
2960 return failure();
2961 return success();
2962}
2963
2964Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
2965 Location loc, Type resultType,
2966 ValueRange inputs) const {
2967 for (const SourceMaterializationCallbackFn &fn :
2968 llvm::reverse(C: sourceMaterializations))
2969 if (Value result = fn(builder, resultType, inputs, loc))
2970 return result;
2971 return nullptr;
2972}
2973
2974Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
2975 Location loc, Type resultType,
2976 ValueRange inputs,
2977 Type originalType) const {
2978 SmallVector<Value> result = materializeTargetConversion(
2979 builder, loc, resultType: TypeRange(resultType), inputs, originalType);
2980 if (result.empty())
2981 return nullptr;
2982 assert(result.size() == 1 && "expected single result");
2983 return result.front();
2984}
2985
2986SmallVector<Value> TypeConverter::materializeTargetConversion(
2987 OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
2988 Type originalType) const {
2989 for (const TargetMaterializationCallbackFn &fn :
2990 llvm::reverse(C: targetMaterializations)) {
2991 SmallVector<Value> result =
2992 fn(builder, resultTypes, inputs, loc, originalType);
2993 if (result.empty())
2994 continue;
2995 assert(TypeRange(ValueRange(result)) == resultTypes &&
2996 "callback produced incorrect number of values or values with "
2997 "incorrect types");
2998 return result;
2999 }
3000 return {};
3001}
3002
3003std::optional<TypeConverter::SignatureConversion>
3004TypeConverter::convertBlockSignature(Block *block) const {
3005 SignatureConversion conversion(block->getNumArguments());
3006 if (failed(Result: convertSignatureArgs(types: block->getArgumentTypes(), result&: conversion)))
3007 return std::nullopt;
3008 return conversion;
3009}
3010
3011//===----------------------------------------------------------------------===//
3012// Type attribute conversion
3013//===----------------------------------------------------------------------===//
3014TypeConverter::AttributeConversionResult
3015TypeConverter::AttributeConversionResult::result(Attribute attr) {
3016 return AttributeConversionResult(attr, resultTag);
3017}
3018
3019TypeConverter::AttributeConversionResult
3020TypeConverter::AttributeConversionResult::na() {
3021 return AttributeConversionResult(nullptr, naTag);
3022}
3023
3024TypeConverter::AttributeConversionResult
3025TypeConverter::AttributeConversionResult::abort() {
3026 return AttributeConversionResult(nullptr, abortTag);
3027}
3028
3029bool TypeConverter::AttributeConversionResult::hasResult() const {
3030 return impl.getInt() == resultTag;
3031}
3032
3033bool TypeConverter::AttributeConversionResult::isNa() const {
3034 return impl.getInt() == naTag;
3035}
3036
3037bool TypeConverter::AttributeConversionResult::isAbort() const {
3038 return impl.getInt() == abortTag;
3039}
3040
3041Attribute TypeConverter::AttributeConversionResult::getResult() const {
3042 assert(hasResult() && "Cannot get result from N/A or abort");
3043 return impl.getPointer();
3044}
3045
3046std::optional<Attribute>
3047TypeConverter::convertTypeAttribute(Type type, Attribute attr) const {
3048 for (const TypeAttributeConversionCallbackFn &fn :
3049 llvm::reverse(C: typeAttributeConversions)) {
3050 AttributeConversionResult res = fn(type, attr);
3051 if (res.hasResult())
3052 return res.getResult();
3053 if (res.isAbort())
3054 return std::nullopt;
3055 }
3056 return std::nullopt;
3057}
3058
3059//===----------------------------------------------------------------------===//
3060// FunctionOpInterfaceSignatureConversion
3061//===----------------------------------------------------------------------===//
3062
3063static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3064 const TypeConverter &typeConverter,
3065 ConversionPatternRewriter &rewriter) {
3066 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3067 if (!type)
3068 return failure();
3069
3070 // Convert the original function types.
3071 TypeConverter::SignatureConversion result(type.getNumInputs());
3072 SmallVector<Type, 1> newResults;
3073 if (failed(typeConverter.convertSignatureArgs(types: type.getInputs(), result)) ||
3074 failed(typeConverter.convertTypes(types: type.getResults(), results&: newResults)) ||
3075 failed(rewriter.convertRegionTypes(region: &funcOp.getFunctionBody(),
3076 converter: typeConverter, entryConversion: &result)))
3077 return failure();
3078
3079 // Update the function signature in-place.
3080 auto newType = FunctionType::get(rewriter.getContext(),
3081 result.getConvertedTypes(), newResults);
3082
3083 rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3084
3085 return success();
3086}
3087
3088/// Create a default conversion pattern that rewrites the type signature of a
3089/// FunctionOpInterface op. This only supports ops which use FunctionType to
3090/// represent their type.
3091namespace {
3092struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3093 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3094 MLIRContext *ctx,
3095 const TypeConverter &converter)
3096 : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
3097
3098 LogicalResult
3099 matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3100 ConversionPatternRewriter &rewriter) const override {
3101 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3102 return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3103 }
3104};
3105
3106struct AnyFunctionOpInterfaceSignatureConversion
3107 : public OpInterfaceConversionPattern<FunctionOpInterface> {
3108 using OpInterfaceConversionPattern::OpInterfaceConversionPattern;
3109
3110 LogicalResult
3111 matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3112 ConversionPatternRewriter &rewriter) const override {
3113 return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3114 }
3115};
3116} // namespace
3117
3118FailureOr<Operation *>
3119mlir::convertOpResultTypes(Operation *op, ValueRange operands,
3120 const TypeConverter &converter,
3121 ConversionPatternRewriter &rewriter) {
3122 assert(op && "Invalid op");
3123 Location loc = op->getLoc();
3124 if (converter.isLegal(op))
3125 return rewriter.notifyMatchFailure(arg&: loc, msg: "op already legal");
3126
3127 OperationState newOp(loc, op->getName());
3128 newOp.addOperands(newOperands: operands);
3129
3130 SmallVector<Type> newResultTypes;
3131 if (failed(Result: converter.convertTypes(types: op->getResultTypes(), results&: newResultTypes)))
3132 return rewriter.notifyMatchFailure(arg&: loc, msg: "couldn't convert return types");
3133
3134 newOp.addTypes(newTypes: newResultTypes);
3135 newOp.addAttributes(newAttributes: op->getAttrs());
3136 return rewriter.create(state: newOp);
3137}
3138
3139void mlir::populateFunctionOpInterfaceTypeConversionPattern(
3140 StringRef functionLikeOpName, RewritePatternSet &patterns,
3141 const TypeConverter &converter) {
3142 patterns.add<FunctionOpInterfaceSignatureConversion>(
3143 arg&: functionLikeOpName, args: patterns.getContext(), args: converter);
3144}
3145
3146void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
3147 RewritePatternSet &patterns, const TypeConverter &converter) {
3148 patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3149 arg: converter, args: patterns.getContext());
3150}
3151
3152//===----------------------------------------------------------------------===//
3153// ConversionTarget
3154//===----------------------------------------------------------------------===//
3155
3156void ConversionTarget::setOpAction(OperationName op,
3157 LegalizationAction action) {
3158 legalOperations[op].action = action;
3159}
3160
3161void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
3162 LegalizationAction action) {
3163 for (StringRef dialect : dialectNames)
3164 legalDialects[dialect] = action;
3165}
3166
3167auto ConversionTarget::getOpAction(OperationName op) const
3168 -> std::optional<LegalizationAction> {
3169 std::optional<LegalizationInfo> info = getOpInfo(op);
3170 return info ? info->action : std::optional<LegalizationAction>();
3171}
3172
3173auto ConversionTarget::isLegal(Operation *op) const
3174 -> std::optional<LegalOpDetails> {
3175 std::optional<LegalizationInfo> info = getOpInfo(op: op->getName());
3176 if (!info)
3177 return std::nullopt;
3178
3179 // Returns true if this operation instance is known to be legal.
3180 auto isOpLegal = [&] {
3181 // Handle dynamic legality either with the provided legality function.
3182 if (info->action == LegalizationAction::Dynamic) {
3183 std::optional<bool> result = info->legalityFn(op);
3184 if (result)
3185 return *result;
3186 }
3187
3188 // Otherwise, the operation is only legal if it was marked 'Legal'.
3189 return info->action == LegalizationAction::Legal;
3190 };
3191 if (!isOpLegal())
3192 return std::nullopt;
3193
3194 // This operation is legal, compute any additional legality information.
3195 LegalOpDetails legalityDetails;
3196 if (info->isRecursivelyLegal) {
3197 auto legalityFnIt = opRecursiveLegalityFns.find(Val: op->getName());
3198 if (legalityFnIt != opRecursiveLegalityFns.end()) {
3199 legalityDetails.isRecursivelyLegal =
3200 legalityFnIt->second(op).value_or(u: true);
3201 } else {
3202 legalityDetails.isRecursivelyLegal = true;
3203 }
3204 }
3205 return legalityDetails;
3206}
3207
3208bool ConversionTarget::isIllegal(Operation *op) const {
3209 std::optional<LegalizationInfo> info = getOpInfo(op: op->getName());
3210 if (!info)
3211 return false;
3212
3213 if (info->action == LegalizationAction::Dynamic) {
3214 std::optional<bool> result = info->legalityFn(op);
3215 if (!result)
3216 return false;
3217
3218 return !(*result);
3219 }
3220
3221 return info->action == LegalizationAction::Illegal;
3222}
3223
3224static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(
3225 ConversionTarget::DynamicLegalityCallbackFn oldCallback,
3226 ConversionTarget::DynamicLegalityCallbackFn newCallback) {
3227 if (!oldCallback)
3228 return newCallback;
3229
3230 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3231 Operation *op) -> std::optional<bool> {
3232 if (std::optional<bool> result = newCl(op))
3233 return *result;
3234
3235 return oldCl(op);
3236 };
3237 return chain;
3238}
3239
3240void ConversionTarget::setLegalityCallback(
3241 OperationName name, const DynamicLegalityCallbackFn &callback) {
3242 assert(callback && "expected valid legality callback");
3243 auto *infoIt = legalOperations.find(Key: name);
3244 assert(infoIt != legalOperations.end() &&
3245 infoIt->second.action == LegalizationAction::Dynamic &&
3246 "expected operation to already be marked as dynamically legal");
3247 infoIt->second.legalityFn =
3248 composeLegalityCallbacks(oldCallback: std::move(infoIt->second.legalityFn), newCallback: callback);
3249}
3250
3251void ConversionTarget::markOpRecursivelyLegal(
3252 OperationName name, const DynamicLegalityCallbackFn &callback) {
3253 auto *infoIt = legalOperations.find(Key: name);
3254 assert(infoIt != legalOperations.end() &&
3255 infoIt->second.action != LegalizationAction::Illegal &&
3256 "expected operation to already be marked as legal");
3257 infoIt->second.isRecursivelyLegal = true;
3258 if (callback)
3259 opRecursiveLegalityFns[name] = composeLegalityCallbacks(
3260 oldCallback: std::move(opRecursiveLegalityFns[name]), newCallback: callback);
3261 else
3262 opRecursiveLegalityFns.erase(Val: name);
3263}
3264
3265void ConversionTarget::setLegalityCallback(
3266 ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
3267 assert(callback && "expected valid legality callback");
3268 for (StringRef dialect : dialects)
3269 dialectLegalityFns[dialect] = composeLegalityCallbacks(
3270 oldCallback: std::move(dialectLegalityFns[dialect]), newCallback: callback);
3271}
3272
3273void ConversionTarget::setLegalityCallback(
3274 const DynamicLegalityCallbackFn &callback) {
3275 assert(callback && "expected valid legality callback");
3276 unknownLegalityFn = composeLegalityCallbacks(oldCallback: unknownLegalityFn, newCallback: callback);
3277}
3278
3279auto ConversionTarget::getOpInfo(OperationName op) const
3280 -> std::optional<LegalizationInfo> {
3281 // Check for info for this specific operation.
3282 const auto *it = legalOperations.find(Key: op);
3283 if (it != legalOperations.end())
3284 return it->second;
3285 // Check for info for the parent dialect.
3286 auto dialectIt = legalDialects.find(Key: op.getDialectNamespace());
3287 if (dialectIt != legalDialects.end()) {
3288 DynamicLegalityCallbackFn callback;
3289 auto dialectFn = dialectLegalityFns.find(Key: op.getDialectNamespace());
3290 if (dialectFn != dialectLegalityFns.end())
3291 callback = dialectFn->second;
3292 return LegalizationInfo{.action: dialectIt->second, /*isRecursivelyLegal=*/false,
3293 .legalityFn: callback};
3294 }
3295 // Otherwise, check if we mark unknown operations as dynamic.
3296 if (unknownLegalityFn)
3297 return LegalizationInfo{.action: LegalizationAction::Dynamic,
3298 /*isRecursivelyLegal=*/false, .legalityFn: unknownLegalityFn};
3299 return std::nullopt;
3300}
3301
3302#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3303//===----------------------------------------------------------------------===//
3304// PDL Configuration
3305//===----------------------------------------------------------------------===//
3306
3307void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) {
3308 auto &rewriterImpl =
3309 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3310 rewriterImpl.currentTypeConverter = getTypeConverter();
3311}
3312
3313void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) {
3314 auto &rewriterImpl =
3315 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3316 rewriterImpl.currentTypeConverter = nullptr;
3317}
3318
3319/// Remap the given value using the rewriter and the type converter in the
3320/// provided config.
3321static FailureOr<SmallVector<Value>>
3322pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values) {
3323 SmallVector<Value> mappedValues;
3324 if (failed(Result: rewriter.getRemappedValues(keys: values, results&: mappedValues)))
3325 return failure();
3326 return std::move(mappedValues);
3327}
3328
3329void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
3330 patterns.getPDLPatterns().registerRewriteFunction(
3331 name: "convertValue",
3332 rewriteFn: [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
3333 auto results = pdllConvertValues(
3334 rewriter&: static_cast<ConversionPatternRewriter &>(rewriter), values: value);
3335 if (failed(Result: results))
3336 return failure();
3337 return results->front();
3338 });
3339 patterns.getPDLPatterns().registerRewriteFunction(
3340 name: "convertValues", rewriteFn: [](PatternRewriter &rewriter, ValueRange values) {
3341 return pdllConvertValues(
3342 rewriter&: static_cast<ConversionPatternRewriter &>(rewriter), values);
3343 });
3344 patterns.getPDLPatterns().registerRewriteFunction(
3345 name: "convertType",
3346 rewriteFn: [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
3347 auto &rewriterImpl =
3348 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3349 if (const TypeConverter *converter =
3350 rewriterImpl.currentTypeConverter) {
3351 if (Type newType = converter->convertType(t: type))
3352 return newType;
3353 return failure();
3354 }
3355 return type;
3356 });
3357 patterns.getPDLPatterns().registerRewriteFunction(
3358 name: "convertTypes",
3359 rewriteFn: [](PatternRewriter &rewriter,
3360 TypeRange types) -> FailureOr<SmallVector<Type>> {
3361 auto &rewriterImpl =
3362 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3363 const TypeConverter *converter = rewriterImpl.currentTypeConverter;
3364 if (!converter)
3365 return SmallVector<Type>(types);
3366
3367 SmallVector<Type> remappedTypes;
3368 if (failed(Result: converter->convertTypes(types, results&: remappedTypes)))
3369 return failure();
3370 return std::move(remappedTypes);
3371 });
3372}
3373#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
3374
3375//===----------------------------------------------------------------------===//
3376// Op Conversion Entry Points
3377//===----------------------------------------------------------------------===//
3378
3379//===----------------------------------------------------------------------===//
3380// Partial Conversion
3381//===----------------------------------------------------------------------===//
3382
3383LogicalResult mlir::applyPartialConversion(
3384 ArrayRef<Operation *> ops, const ConversionTarget &target,
3385 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3386 OperationConverter opConverter(target, patterns, config,
3387 OpConversionMode::Partial);
3388 return opConverter.convertOperations(ops);
3389}
3390LogicalResult
3391mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
3392 const FrozenRewritePatternSet &patterns,
3393 ConversionConfig config) {
3394 return applyPartialConversion(ops: llvm::ArrayRef(op), target, patterns, config);
3395}
3396
3397//===----------------------------------------------------------------------===//
3398// Full Conversion
3399//===----------------------------------------------------------------------===//
3400
3401LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
3402 const ConversionTarget &target,
3403 const FrozenRewritePatternSet &patterns,
3404 ConversionConfig config) {
3405 OperationConverter opConverter(target, patterns, config,
3406 OpConversionMode::Full);
3407 return opConverter.convertOperations(ops);
3408}
3409LogicalResult mlir::applyFullConversion(Operation *op,
3410 const ConversionTarget &target,
3411 const FrozenRewritePatternSet &patterns,
3412 ConversionConfig config) {
3413 return applyFullConversion(ops: llvm::ArrayRef(op), target, patterns, config);
3414}
3415
3416//===----------------------------------------------------------------------===//
3417// Analysis Conversion
3418//===----------------------------------------------------------------------===//
3419
3420/// Find a common IsolatedFromAbove ancestor of the given ops. If at least one
3421/// op is a top-level module op (which is expected to be isolated from above),
3422/// return that op.
3423static Operation *findCommonAncestor(ArrayRef<Operation *> ops) {
3424 // Check if there is a top-level operation within `ops`. If so, return that
3425 // op.
3426 for (Operation *op : ops) {
3427 if (!op->getParentOp()) {
3428#ifndef NDEBUG
3429 assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
3430 "expected top-level op to be isolated from above");
3431 for (Operation *other : ops)
3432 assert(op->isAncestor(other) &&
3433 "expected ops to have a common ancestor");
3434#endif // NDEBUG
3435 return op;
3436 }
3437 }
3438
3439 // No top-level op. Find a common ancestor.
3440 Operation *commonAncestor =
3441 ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
3442 for (Operation *op : ops.drop_front()) {
3443 while (!commonAncestor->isProperAncestor(other: op)) {
3444 commonAncestor =
3445 commonAncestor->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
3446 assert(commonAncestor &&
3447 "expected to find a common isolated from above ancestor");
3448 }
3449 }
3450
3451 return commonAncestor;
3452}
3453
3454LogicalResult mlir::applyAnalysisConversion(
3455 ArrayRef<Operation *> ops, ConversionTarget &target,
3456 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3457#ifndef NDEBUG
3458 if (config.legalizableOps)
3459 assert(config.legalizableOps->empty() && "expected empty set");
3460#endif // NDEBUG
3461
3462 // Clone closted common ancestor that is isolated from above.
3463 Operation *commonAncestor = findCommonAncestor(ops);
3464 IRMapping mapping;
3465 Operation *clonedAncestor = commonAncestor->clone(mapper&: mapping);
3466 // Compute inverse IR mapping.
3467 DenseMap<Operation *, Operation *> inverseOperationMap;
3468 for (auto &it : mapping.getOperationMap())
3469 inverseOperationMap[it.second] = it.first;
3470
3471 // Convert the cloned operations. The original IR will remain unchanged.
3472 SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
3473 C&: ops, F: [&](Operation *op) { return mapping.lookup(from: op); });
3474 OperationConverter opConverter(target, patterns, config,
3475 OpConversionMode::Analysis);
3476 LogicalResult status = opConverter.convertOperations(ops: opsToConvert);
3477
3478 // Remap `legalizableOps`, so that they point to the original ops and not the
3479 // cloned ops.
3480 if (config.legalizableOps) {
3481 DenseSet<Operation *> originalLegalizableOps;
3482 for (Operation *op : *config.legalizableOps)
3483 originalLegalizableOps.insert(V: inverseOperationMap[op]);
3484 *config.legalizableOps = std::move(originalLegalizableOps);
3485 }
3486
3487 // Erase the cloned IR.
3488 clonedAncestor->erase();
3489 return status;
3490}
3491
3492LogicalResult
3493mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
3494 const FrozenRewritePatternSet &patterns,
3495 ConversionConfig config) {
3496 return applyAnalysisConversion(ops: llvm::ArrayRef(op), target, patterns, config);
3497}
3498

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Transforms/Utils/DialectConversion.cpp