1//===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Config/mlir-config.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Iterators.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Rewrite/PatternApplicator.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/ScopedPrinter.h"
#include <optional>
#include <utility>
26
27using namespace mlir;
28using namespace mlir::detail;
29
30#define DEBUG_TYPE "dialect-conversion"
31
32/// A utility function to log a successful result for the given reason.
33template <typename... Args>
34static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
35 LLVM_DEBUG({
36 os.unindent();
37 os.startLine() << "} -> SUCCESS";
38 if (!fmt.empty())
39 os.getOStream() << " : "
40 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
41 os.getOStream() << "\n";
42 });
43}
44
45/// A utility function to log a failure result for the given reason.
46template <typename... Args>
47static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
48 LLVM_DEBUG({
49 os.unindent();
50 os.startLine() << "} -> FAILURE : "
51 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
52 << "\n";
53 });
54}
55
56//===----------------------------------------------------------------------===//
57// ConversionValueMapping
58//===----------------------------------------------------------------------===//
59
60namespace {
61/// This class wraps a IRMapping to provide recursive lookup
62/// functionality, i.e. we will traverse if the mapped value also has a mapping.
63struct ConversionValueMapping {
64 /// Lookup a mapped value within the map. If a mapping for the provided value
65 /// does not exist then return the provided value. If `desiredType` is
66 /// non-null, returns the most recently mapped value with that type. If an
67 /// operand of that type does not exist, defaults to normal behavior.
68 Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
69
70 /// Lookup a mapped value within the map, or return null if a mapping does not
71 /// exist. If a mapping exists, this follows the same behavior of
72 /// `lookupOrDefault`.
73 Value lookupOrNull(Value from, Type desiredType = nullptr) const;
74
75 /// Map a value to the one provided.
76 void map(Value oldVal, Value newVal) {
77 LLVM_DEBUG({
78 for (Value it = newVal; it; it = mapping.lookupOrNull(it))
79 assert(it != oldVal && "inserting cyclic mapping");
80 });
81 mapping.map(from: oldVal, to: newVal);
82 }
83
84 /// Try to map a value to the one provided. Returns false if a transitive
85 /// mapping from the new value to the old value already exists, true if the
86 /// map was updated.
87 bool tryMap(Value oldVal, Value newVal);
88
89 /// Drop the last mapping for the given value.
90 void erase(Value value) { mapping.erase(from: value); }
91
92 /// Returns the inverse raw value mapping (without recursive query support).
93 DenseMap<Value, SmallVector<Value>> getInverse() const {
94 DenseMap<Value, SmallVector<Value>> inverse;
95 for (auto &it : mapping.getValueMap())
96 inverse[it.second].push_back(Elt: it.first);
97 return inverse;
98 }
99
100private:
101 /// Current value mappings.
102 IRMapping mapping;
103};
104} // namespace
105
106Value ConversionValueMapping::lookupOrDefault(Value from,
107 Type desiredType) const {
108 // If there was no desired type, simply find the leaf value.
109 if (!desiredType) {
110 // If this value had a valid mapping, unmap that value as well in the case
111 // that it was also replaced.
112 while (auto mappedValue = mapping.lookupOrNull(from))
113 from = mappedValue;
114 return from;
115 }
116
117 // Otherwise, try to find the deepest value that has the desired type.
118 Value desiredValue;
119 do {
120 if (from.getType() == desiredType)
121 desiredValue = from;
122
123 Value mappedValue = mapping.lookupOrNull(from);
124 if (!mappedValue)
125 break;
126 from = mappedValue;
127 } while (true);
128
129 // If the desired value was found use it, otherwise default to the leaf value.
130 return desiredValue ? desiredValue : from;
131}
132
133Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const {
134 Value result = lookupOrDefault(from, desiredType);
135 if (result == from || (desiredType && result.getType() != desiredType))
136 return nullptr;
137 return result;
138}
139
140bool ConversionValueMapping::tryMap(Value oldVal, Value newVal) {
141 for (Value it = newVal; it; it = mapping.lookupOrNull(from: it))
142 if (it == oldVal)
143 return false;
144 map(oldVal, newVal);
145 return true;
146}
147
148//===----------------------------------------------------------------------===//
149// Rewriter and Translation State
150//===----------------------------------------------------------------------===//
151namespace {
/// Snapshot of the bookkeeping counters of the conversion rewriter. Saving one
/// and later resetting to it allows a group of rewrites to be undone.
struct RewriterState {
  RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
                unsigned numReplacedOps)
      : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
        numReplacedOps(numReplacedOps) {}

  /// How many rewrites had been recorded.
  unsigned numRewrites;

  /// How many operations had been marked as ignored.
  unsigned numIgnoredOperations;

  /// How many replaced ops were scheduled for erasure.
  unsigned numReplacedOps;
};
169
170//===----------------------------------------------------------------------===//
171// IR rewrites
172//===----------------------------------------------------------------------===//
173
174/// An IR rewrite that can be committed (upon success) or rolled back (upon
175/// failure).
176///
177/// The dialect conversion keeps track of IR modifications (requested by the
178/// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
179/// are directly applied to the IR as the rewriter API is used, some are applied
180/// partially, and some are delayed until the `IRRewrite` objects are committed.
181class IRRewrite {
182public:
183 /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
184 /// Enum values are ordered, so that they can be used in `classof`: first all
185 /// block rewrites, then all operation rewrites.
186 enum class Kind {
187 // Block rewrites
188 CreateBlock,
189 EraseBlock,
190 InlineBlock,
191 MoveBlock,
192 BlockTypeConversion,
193 ReplaceBlockArg,
194 // Operation rewrites
195 MoveOperation,
196 ModifyOperation,
197 ReplaceOperation,
198 CreateOperation,
199 UnresolvedMaterialization
200 };
201
202 virtual ~IRRewrite() = default;
203
204 /// Roll back the rewrite. Operations may be erased during rollback.
205 virtual void rollback() = 0;
206
207 /// Commit the rewrite. At this point, it is certain that the dialect
208 /// conversion will succeed. All IR modifications, except for operation/block
209 /// erasure, must be performed through the given rewriter.
210 ///
211 /// Instead of erasing operations/blocks, they should merely be unlinked
212 /// commit phase and finally be erased during the cleanup phase. This is
213 /// because internal dialect conversion state (such as `mapping`) may still
214 /// be using them.
215 ///
216 /// Any IR modification that was already performed before the commit phase
217 /// (e.g., insertion of an op) must be communicated to the listener that may
218 /// be attached to the given rewriter.
219 virtual void commit(RewriterBase &rewriter) {}
220
221 /// Cleanup operations/blocks. Cleanup is called after commit.
222 virtual void cleanup(RewriterBase &rewriter) {}
223
224 Kind getKind() const { return kind; }
225
226 static bool classof(const IRRewrite *rewrite) { return true; }
227
228protected:
229 IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
230 : kind(kind), rewriterImpl(rewriterImpl) {}
231
232 const ConversionConfig &getConfig() const;
233
234 const Kind kind;
235 ConversionPatternRewriterImpl &rewriterImpl;
236};
237
238/// A block rewrite.
239class BlockRewrite : public IRRewrite {
240public:
241 /// Return the block that this rewrite operates on.
242 Block *getBlock() const { return block; }
243
244 static bool classof(const IRRewrite *rewrite) {
245 return rewrite->getKind() >= Kind::CreateBlock &&
246 rewrite->getKind() <= Kind::ReplaceBlockArg;
247 }
248
249protected:
250 BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
251 Block *block)
252 : IRRewrite(kind, rewriterImpl), block(block) {}
253
254 // The block that this rewrite operates on.
255 Block *block;
256};
257
258/// Creation of a block. Block creations are immediately reflected in the IR.
259/// There is no extra work to commit the rewrite. During rollback, the newly
260/// created block is erased.
261class CreateBlockRewrite : public BlockRewrite {
262public:
263 CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
264 : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
265
266 static bool classof(const IRRewrite *rewrite) {
267 return rewrite->getKind() == Kind::CreateBlock;
268 }
269
270 void commit(RewriterBase &rewriter) override {
271 // The block was already created and inserted. Just inform the listener.
272 if (auto *listener = rewriter.getListener())
273 listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
274 }
275
276 void rollback() override {
277 // Unlink all of the operations within this block, they will be deleted
278 // separately.
279 auto &blockOps = block->getOperations();
280 while (!blockOps.empty())
281 blockOps.remove(IT: blockOps.begin());
282 block->dropAllUses();
283 if (block->getParent())
284 block->erase();
285 else
286 delete block;
287 }
288};
289
290/// Erasure of a block. Block erasures are partially reflected in the IR. Erased
291/// blocks are immediately unlinked, but only erased during cleanup. This makes
292/// it easier to rollback a block erasure: the block is simply inserted into its
293/// original location.
294class EraseBlockRewrite : public BlockRewrite {
295public:
296 EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
297 Region *region, Block *insertBeforeBlock)
298 : BlockRewrite(Kind::EraseBlock, rewriterImpl, block), region(region),
299 insertBeforeBlock(insertBeforeBlock) {}
300
301 static bool classof(const IRRewrite *rewrite) {
302 return rewrite->getKind() == Kind::EraseBlock;
303 }
304
305 ~EraseBlockRewrite() override {
306 assert(!block &&
307 "rewrite was neither rolled back nor committed/cleaned up");
308 }
309
310 void rollback() override {
311 // The block (owned by this rewrite) was not actually erased yet. It was
312 // just unlinked. Put it back into its original position.
313 assert(block && "expected block");
314 auto &blockList = region->getBlocks();
315 Region::iterator before = insertBeforeBlock
316 ? Region::iterator(insertBeforeBlock)
317 : blockList.end();
318 blockList.insert(where: before, New: block);
319 block = nullptr;
320 }
321
322 void commit(RewriterBase &rewriter) override {
323 // Erase the block.
324 assert(block && "expected block");
325 assert(block->empty() && "expected empty block");
326
327 // Notify the listener that the block is about to be erased.
328 if (auto *listener =
329 dyn_cast_or_null<RewriterBase::Listener>(Val: rewriter.getListener()))
330 listener->notifyBlockErased(block);
331 }
332
333 void cleanup(RewriterBase &rewriter) override {
334 // Erase the block.
335 block->dropAllDefinedValueUses();
336 delete block;
337 block = nullptr;
338 }
339
340private:
341 // The region in which this block was previously contained.
342 Region *region;
343
344 // The original successor of this block before it was unlinked. "nullptr" if
345 // this block was the only block in the region.
346 Block *insertBeforeBlock;
347};
348
349/// Inlining of a block. This rewrite is immediately reflected in the IR.
350/// Note: This rewrite represents only the inlining of the operations. The
351/// erasure of the inlined block is a separate rewrite.
352class InlineBlockRewrite : public BlockRewrite {
353public:
354 InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
355 Block *sourceBlock, Block::iterator before)
356 : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
357 sourceBlock(sourceBlock),
358 firstInlinedInst(sourceBlock->empty() ? nullptr
359 : &sourceBlock->front()),
360 lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
361 // If a listener is attached to the dialect conversion, ops must be moved
362 // one-by-one. When they are moved in bulk, notifications cannot be sent
363 // because the ops that used to be in the source block at the time of the
364 // inlining (before the "commit" phase) are unknown at the time when
365 // notifications are sent (which is during the "commit" phase).
366 assert(!getConfig().listener &&
367 "InlineBlockRewrite not supported if listener is attached");
368 }
369
370 static bool classof(const IRRewrite *rewrite) {
371 return rewrite->getKind() == Kind::InlineBlock;
372 }
373
374 void rollback() override {
375 // Put the operations from the destination block (owned by the rewrite)
376 // back into the source block.
377 if (firstInlinedInst) {
378 assert(lastInlinedInst && "expected operation");
379 sourceBlock->getOperations().splice(where: sourceBlock->begin(),
380 L2&: block->getOperations(),
381 first: Block::iterator(firstInlinedInst),
382 last: ++Block::iterator(lastInlinedInst));
383 }
384 }
385
386private:
387 // The block that originally contained the operations.
388 Block *sourceBlock;
389
390 // The first inlined operation.
391 Operation *firstInlinedInst;
392
393 // The last inlined operation.
394 Operation *lastInlinedInst;
395};
396
397/// Moving of a block. This rewrite is immediately reflected in the IR.
398class MoveBlockRewrite : public BlockRewrite {
399public:
400 MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
401 Region *region, Block *insertBeforeBlock)
402 : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region),
403 insertBeforeBlock(insertBeforeBlock) {}
404
405 static bool classof(const IRRewrite *rewrite) {
406 return rewrite->getKind() == Kind::MoveBlock;
407 }
408
409 void commit(RewriterBase &rewriter) override {
410 // The block was already moved. Just inform the listener.
411 if (auto *listener = rewriter.getListener()) {
412 // Note: `previousIt` cannot be passed because this is a delayed
413 // notification and iterators into past IR state cannot be represented.
414 listener->notifyBlockInserted(block, /*previous=*/region,
415 /*previousIt=*/{});
416 }
417 }
418
419 void rollback() override {
420 // Move the block back to its original position.
421 Region::iterator before =
422 insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
423 region->getBlocks().splice(where: before, L2&: block->getParent()->getBlocks(), N: block);
424 }
425
426private:
427 // The region in which this block was previously contained.
428 Region *region;
429
430 // The original successor of this block before it was moved. "nullptr" if
431 // this block was the only block in the region.
432 Block *insertBeforeBlock;
433};
434
435/// This structure contains the information pertaining to an argument that has
436/// been converted.
437struct ConvertedArgInfo {
438 ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
439 Value castValue = nullptr)
440 : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
441
442 /// The start index of in the new argument list that contains arguments that
443 /// replace the original.
444 unsigned newArgIdx;
445
446 /// The number of arguments that replaced the original argument.
447 unsigned newArgSize;
448
449 /// The cast value that was created to cast from the new arguments to the
450 /// old. This only used if 'newArgSize' > 1.
451 Value castValue;
452};
453
454/// Block type conversion. This rewrite is partially reflected in the IR.
455class BlockTypeConversionRewrite : public BlockRewrite {
456public:
457 BlockTypeConversionRewrite(
458 ConversionPatternRewriterImpl &rewriterImpl, Block *block,
459 Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo,
460 const TypeConverter *converter)
461 : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
462 origBlock(origBlock), argInfo(argInfo), converter(converter) {}
463
464 static bool classof(const IRRewrite *rewrite) {
465 return rewrite->getKind() == Kind::BlockTypeConversion;
466 }
467
468 /// Materialize any necessary conversions for converted arguments that have
469 /// live users, using the provided `findLiveUser` to search for a user that
470 /// survives the conversion process.
471 LogicalResult
472 materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
473
474 void commit(RewriterBase &rewriter) override;
475
476 void rollback() override;
477
478private:
479 /// The original block that was requested to have its signature converted.
480 Block *origBlock;
481
482 /// The conversion information for each of the arguments. The information is
483 /// std::nullopt if the argument was dropped during conversion.
484 SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
485
486 /// The type converter used to convert the arguments.
487 const TypeConverter *converter;
488};
489
490/// Replacing a block argument. This rewrite is not immediately reflected in the
491/// IR. An internal IR mapping is updated, but the actual replacement is delayed
492/// until the rewrite is committed.
493class ReplaceBlockArgRewrite : public BlockRewrite {
494public:
495 ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
496 Block *block, BlockArgument arg)
497 : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {}
498
499 static bool classof(const IRRewrite *rewrite) {
500 return rewrite->getKind() == Kind::ReplaceBlockArg;
501 }
502
503 void commit(RewriterBase &rewriter) override;
504
505 void rollback() override;
506
507private:
508 BlockArgument arg;
509};
510
511/// An operation rewrite.
512class OperationRewrite : public IRRewrite {
513public:
514 /// Return the operation that this rewrite operates on.
515 Operation *getOperation() const { return op; }
516
517 static bool classof(const IRRewrite *rewrite) {
518 return rewrite->getKind() >= Kind::MoveOperation &&
519 rewrite->getKind() <= Kind::UnresolvedMaterialization;
520 }
521
522protected:
523 OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
524 Operation *op)
525 : IRRewrite(kind, rewriterImpl), op(op) {}
526
527 // The operation that this rewrite operates on.
528 Operation *op;
529};
530
531/// Moving of an operation. This rewrite is immediately reflected in the IR.
532class MoveOperationRewrite : public OperationRewrite {
533public:
534 MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
535 Operation *op, Block *block, Operation *insertBeforeOp)
536 : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block),
537 insertBeforeOp(insertBeforeOp) {}
538
539 static bool classof(const IRRewrite *rewrite) {
540 return rewrite->getKind() == Kind::MoveOperation;
541 }
542
543 void commit(RewriterBase &rewriter) override {
544 // The operation was already moved. Just inform the listener.
545 if (auto *listener = rewriter.getListener()) {
546 // Note: `previousIt` cannot be passed because this is a delayed
547 // notification and iterators into past IR state cannot be represented.
548 listener->notifyOperationInserted(
549 op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
550 /*insertPt=*/{}));
551 }
552 }
553
554 void rollback() override {
555 // Move the operation back to its original position.
556 Block::iterator before =
557 insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
558 block->getOperations().splice(where: before, L2&: op->getBlock()->getOperations(), N: op);
559 }
560
561private:
562 // The block in which this operation was previously contained.
563 Block *block;
564
565 // The original successor of this operation before it was moved. "nullptr"
566 // if this operation was the only operation in the region.
567 Operation *insertBeforeOp;
568};
569
570/// In-place modification of an op. This rewrite is immediately reflected in
571/// the IR. The previous state of the operation is stored in this object.
572class ModifyOperationRewrite : public OperationRewrite {
573public:
574 ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
575 Operation *op)
576 : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
577 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
578 operands(op->operand_begin(), op->operand_end()),
579 successors(op->successor_begin(), op->successor_end()) {
580 if (OpaqueProperties prop = op->getPropertiesStorage()) {
581 // Make a copy of the properties.
582 propertiesStorage = operator new(op->getPropertiesStorageSize());
583 OpaqueProperties propCopy(propertiesStorage);
584 name.initOpProperties(storage: propCopy, /*init=*/prop);
585 }
586 }
587
588 static bool classof(const IRRewrite *rewrite) {
589 return rewrite->getKind() == Kind::ModifyOperation;
590 }
591
592 ~ModifyOperationRewrite() override {
593 assert(!propertiesStorage &&
594 "rewrite was neither committed nor rolled back");
595 }
596
597 void commit(RewriterBase &rewriter) override {
598 // Notify the listener that the operation was modified in-place.
599 if (auto *listener =
600 dyn_cast_or_null<RewriterBase::Listener>(Val: rewriter.getListener()))
601 listener->notifyOperationModified(op);
602
603 if (propertiesStorage) {
604 OpaqueProperties propCopy(propertiesStorage);
605 // Note: The operation may have been erased in the mean time, so
606 // OperationName must be stored in this object.
607 name.destroyOpProperties(properties: propCopy);
608 operator delete(propertiesStorage);
609 propertiesStorage = nullptr;
610 }
611 }
612
613 void rollback() override {
614 op->setLoc(loc);
615 op->setAttrs(attrs);
616 op->setOperands(operands);
617 for (const auto &it : llvm::enumerate(First&: successors))
618 op->setSuccessor(block: it.value(), index: it.index());
619 if (propertiesStorage) {
620 OpaqueProperties propCopy(propertiesStorage);
621 op->copyProperties(rhs: propCopy);
622 name.destroyOpProperties(properties: propCopy);
623 operator delete(propertiesStorage);
624 propertiesStorage = nullptr;
625 }
626 }
627
628private:
629 OperationName name;
630 LocationAttr loc;
631 DictionaryAttr attrs;
632 SmallVector<Value, 8> operands;
633 SmallVector<Block *, 2> successors;
634 void *propertiesStorage = nullptr;
635};
636
637/// Replacing an operation. Erasing an operation is treated as a special case
638/// with "null" replacements. This rewrite is not immediately reflected in the
639/// IR. An internal IR mapping is updated, but values are not replaced and the
640/// original op is not erased until the rewrite is committed.
641class ReplaceOperationRewrite : public OperationRewrite {
642public:
643 ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
644 Operation *op, const TypeConverter *converter,
645 bool changedResults)
646 : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
647 converter(converter), changedResults(changedResults) {}
648
649 static bool classof(const IRRewrite *rewrite) {
650 return rewrite->getKind() == Kind::ReplaceOperation;
651 }
652
653 void commit(RewriterBase &rewriter) override;
654
655 void rollback() override;
656
657 void cleanup(RewriterBase &rewriter) override;
658
659 const TypeConverter *getConverter() const { return converter; }
660
661 bool hasChangedResults() const { return changedResults; }
662
663private:
664 /// An optional type converter that can be used to materialize conversions
665 /// between the new and old values if necessary.
666 const TypeConverter *converter;
667
668 /// A boolean flag that indicates whether result types have changed or not.
669 bool changedResults;
670};
671
672class CreateOperationRewrite : public OperationRewrite {
673public:
674 CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
675 Operation *op)
676 : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
677
678 static bool classof(const IRRewrite *rewrite) {
679 return rewrite->getKind() == Kind::CreateOperation;
680 }
681
682 void commit(RewriterBase &rewriter) override {
683 // The operation was already created and inserted. Just inform the listener.
684 if (auto *listener = rewriter.getListener())
685 listener->notifyOperationInserted(op, /*previous=*/{});
686 }
687
688 void rollback() override;
689};
690
/// The flavor of an unresolved materialization. The values are spelled out
/// because they are stored in a single bit of a `PointerIntPair`.
enum MaterializationKind {
  /// Materializes a conversion for an illegal block argument type to a legal
  /// one.
  Argument = 0,

  /// Materializes a conversion from an illegal type to a legal one.
  Target = 1
};
701
702/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
703/// op. Unresolved materializations are erased at the end of the dialect
704/// conversion.
705class UnresolvedMaterializationRewrite : public OperationRewrite {
706public:
707 UnresolvedMaterializationRewrite(
708 ConversionPatternRewriterImpl &rewriterImpl,
709 UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
710 MaterializationKind kind = MaterializationKind::Target,
711 Type origOutputType = nullptr)
712 : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
713 converterAndKind(converter, kind), origOutputType(origOutputType) {}
714
715 static bool classof(const IRRewrite *rewrite) {
716 return rewrite->getKind() == Kind::UnresolvedMaterialization;
717 }
718
719 UnrealizedConversionCastOp getOperation() const {
720 return cast<UnrealizedConversionCastOp>(op);
721 }
722
723 void rollback() override;
724
725 void cleanup(RewriterBase &rewriter) override;
726
727 /// Return the type converter of this materialization (which may be null).
728 const TypeConverter *getConverter() const {
729 return converterAndKind.getPointer();
730 }
731
732 /// Return the kind of this materialization.
733 MaterializationKind getMaterializationKind() const {
734 return converterAndKind.getInt();
735 }
736
737 /// Set the kind of this materialization.
738 void setMaterializationKind(MaterializationKind kind) {
739 converterAndKind.setInt(kind);
740 }
741
742 /// Return the original illegal output type of the input values.
743 Type getOrigOutputType() const { return origOutputType; }
744
745private:
746 /// The corresponding type converter to use when resolving this
747 /// materialization, and the kind of this materialization.
748 llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
749 converterAndKind;
750
751 /// The original output type. This is only used for argument conversions.
752 Type origOutputType;
753};
754} // namespace
755
756/// Return "true" if there is an operation rewrite that matches the specified
757/// rewrite type and operation among the given rewrites.
758template <typename RewriteTy, typename R>
759static bool hasRewrite(R &&rewrites, Operation *op) {
760 return any_of(std::forward<R>(rewrites), [&](auto &rewrite) {
761 auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
762 return rewriteTy && rewriteTy->getOperation() == op;
763 });
764}
765
766/// Find the single rewrite object of the specified type and block among the
767/// given rewrites. In debug mode, asserts that there is mo more than one such
768/// object. Return "nullptr" if no object was found.
769template <typename RewriteTy, typename R>
770static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
771 RewriteTy *result = nullptr;
772 for (auto &rewrite : rewrites) {
773 auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
774 if (rewriteTy && rewriteTy->getBlock() == block) {
775#ifndef NDEBUG
776 assert(!result && "expected single matching rewrite");
777 result = rewriteTy;
778#else
779 return rewriteTy;
780#endif // NDEBUG
781 }
782 }
783 return result;
784}
785
786//===----------------------------------------------------------------------===//
787// ConversionPatternRewriterImpl
788//===----------------------------------------------------------------------===//
789namespace mlir {
790namespace detail {
791struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
792 explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
793 const ConversionConfig &config)
794 : context(ctx), config(config) {}
795
796 //===--------------------------------------------------------------------===//
797 // State Management
798 //===--------------------------------------------------------------------===//
799
800 /// Return the current state of the rewriter.
801 RewriterState getCurrentState();
802
803 /// Apply all requested operation rewrites. This method is invoked when the
804 /// conversion process succeeds.
805 void applyRewrites();
806
807 /// Reset the state of the rewriter to a previously saved point.
808 void resetState(RewriterState state);
809
810 /// Append a rewrite. Rewrites are committed upon success and rolled back upon
811 /// failure.
812 template <typename RewriteTy, typename... Args>
813 void appendRewrite(Args &&...args) {
814 rewrites.push_back(
815 std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
816 }
817
818 /// Undo the rewrites (motions, splits) one by one in reverse order until
819 /// "numRewritesToKeep" rewrites remains.
820 void undoRewrites(unsigned numRewritesToKeep = 0);
821
822 /// Remap the given values to those with potentially different types. Returns
823 /// success if the values could be remapped, failure otherwise. `valueDiagTag`
824 /// is the tag used when describing a value within a diagnostic, e.g.
825 /// "operand".
826 LogicalResult remapValues(StringRef valueDiagTag,
827 std::optional<Location> inputLoc,
828 PatternRewriter &rewriter, ValueRange values,
829 SmallVectorImpl<Value> &remapped);
830
831 /// Return "true" if the given operation is ignored, and does not need to be
832 /// converted.
833 bool isOpIgnored(Operation *op) const;
834
835 /// Return "true" if the given operation was replaced or erased.
836 bool wasOpReplaced(Operation *op) const;
837
838 //===--------------------------------------------------------------------===//
839 // Type Conversion
840 //===--------------------------------------------------------------------===//
841
842 /// Attempt to convert the signature of the given block, if successful a new
843 /// block is returned containing the new arguments. Returns `block` if it did
844 /// not require conversion.
845 FailureOr<Block *> convertBlockSignature(
846 ConversionPatternRewriter &rewriter, Block *block,
847 const TypeConverter *converter,
848 TypeConverter::SignatureConversion *conversion = nullptr);
849
850 /// Convert the types of non-entry block arguments within the given region.
851 LogicalResult convertNonEntryRegionTypes(
852 ConversionPatternRewriter &rewriter, Region *region,
853 const TypeConverter &converter,
854 ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
855
856 /// Apply a signature conversion on the given region, using `converter` for
857 /// materializations if not null.
858 Block *
859 applySignatureConversion(ConversionPatternRewriter &rewriter, Region *region,
860 TypeConverter::SignatureConversion &conversion,
861 const TypeConverter *converter);
862
863 /// Convert the types of block arguments within the given region.
864 FailureOr<Block *>
865 convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
866 const TypeConverter &converter,
867 TypeConverter::SignatureConversion *entryConversion);
868
869 /// Apply the given signature conversion on the given block. The new block
870 /// containing the updated signature is returned. If no conversions were
871 /// necessary, e.g. if the block has no arguments, `block` is returned.
872 /// `converter` is used to generate any necessary cast operations that
873 /// translate between the origin argument types and those specified in the
874 /// signature conversion.
875 Block *applySignatureConversion(
876 ConversionPatternRewriter &rewriter, Block *block,
877 const TypeConverter *converter,
878 TypeConverter::SignatureConversion &signatureConversion);
879
880 //===--------------------------------------------------------------------===//
881 // Materializations
882 //===--------------------------------------------------------------------===//
883 /// Build an unresolved materialization operation given an output type and set
884 /// of input operands.
885 Value buildUnresolvedMaterialization(MaterializationKind kind,
886 Block *insertBlock,
887 Block::iterator insertPt, Location loc,
888 ValueRange inputs, Type outputType,
889 Type origOutputType,
890 const TypeConverter *converter);
891
892 Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
893 ValueRange inputs,
894 Type origOutputType,
895 Type outputType,
896 const TypeConverter *converter);
897
898 Value buildUnresolvedTargetMaterialization(Location loc, Value input,
899 Type outputType,
900 const TypeConverter *converter);
901
902 //===--------------------------------------------------------------------===//
903 // Rewriter Notification Hooks
904 //===--------------------------------------------------------------------===//
905
 /// Notifies that an op was inserted.
907 void notifyOperationInserted(Operation *op,
908 OpBuilder::InsertPoint previous) override;
909
910 /// Notifies that an op is about to be replaced with the given values.
911 void notifyOpReplaced(Operation *op, ValueRange newValues);
912
913 /// Notifies that a block is about to be erased.
914 void notifyBlockIsBeingErased(Block *block);
915
916 /// Notifies that a block was inserted.
917 void notifyBlockInserted(Block *block, Region *previous,
918 Region::iterator previousIt) override;
919
920 /// Notifies that a block is being inlined into another block.
921 void notifyBlockBeingInlined(Block *block, Block *srcBlock,
922 Block::iterator before);
923
924 /// Notifies that a pattern match failed for the given reason.
925 void
926 notifyMatchFailure(Location loc,
927 function_ref<void(Diagnostic &)> reasonCallback) override;
928
929 //===--------------------------------------------------------------------===//
930 // IR Erasure
931 //===--------------------------------------------------------------------===//
932
933 /// A rewriter that keeps track of erased ops and blocks. It ensures that no
934 /// operation or block is erased multiple times. This rewriter assumes that
935 /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
936 struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener {
937 public:
938 SingleEraseRewriter(MLIRContext *context)
939 : RewriterBase(context, /*listener=*/this) {}
940
941 /// Erase the given op (unless it was already erased).
942 void eraseOp(Operation *op) override {
943 if (erased.contains(V: op))
944 return;
945 op->dropAllUses();
946 RewriterBase::eraseOp(op);
947 }
948
949 /// Erase the given block (unless it was already erased).
950 void eraseBlock(Block *block) override {
951 if (erased.contains(V: block))
952 return;
953 assert(block->empty() && "expected empty block");
954 block->dropAllDefinedValueUses();
955 RewriterBase::eraseBlock(block);
956 }
957
958 void notifyOperationErased(Operation *op) override { erased.insert(V: op); }
959
960 void notifyBlockErased(Block *block) override { erased.insert(V: block); }
961
962 /// Pointers to all erased operations and blocks.
963 DenseSet<void *> erased;
964 };
965
966 //===--------------------------------------------------------------------===//
967 // State
968 //===--------------------------------------------------------------------===//
969
970 /// MLIR context.
971 MLIRContext *context;
972
973 // Mapping between replaced values that differ in type. This happens when
974 // replacing a value with one of a different type.
975 ConversionValueMapping mapping;
976
 /// Ordered list of recorded IR rewrites (e.g. op/block creations, moves,
 /// inlinings, replacements), in the order they were performed.
978 SmallVector<std::unique_ptr<IRRewrite>> rewrites;
979
980 /// A set of operations that should no longer be considered for legalization.
981 /// E.g., ops that are recursively legal. Ops that were replaced/erased are
982 /// tracked separately.
983 SetVector<Operation *> ignoredOps;
984
985 /// A set of operations that were replaced/erased. Such ops are not erased
 /// immediately but only when the dialect conversion succeeds. In the
 /// meantime, they should no longer be considered for legalization and any attempt
988 /// to modify/access them is invalid rewriter API usage.
989 SetVector<Operation *> replacedOps;
990
991 /// The current type converter, or nullptr if no type converter is currently
992 /// active.
993 const TypeConverter *currentTypeConverter = nullptr;
994
995 /// A mapping of regions to type converters that should be used when
996 /// converting the arguments of blocks within that region.
997 DenseMap<Region *, const TypeConverter *> regionToConverter;
998
999 /// Dialect conversion configuration.
1000 const ConversionConfig &config;
1001
1002#ifndef NDEBUG
1003 /// A set of operations that have pending updates. This tracking isn't
1004 /// strictly necessary, and is thus only active during debug builds for extra
1005 /// verification.
1006 SmallPtrSet<Operation *, 1> pendingRootUpdates;
1007
1008 /// A logger used to emit diagnostics during the conversion process.
1009 llvm::ScopedPrinter logger{llvm::dbgs()};
1010#endif
1011};
1012} // namespace detail
1013} // namespace mlir
1014
1015const ConversionConfig &IRRewrite::getConfig() const {
1016 return rewriterImpl.config;
1017}
1018
1019void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1020 // Inform the listener about all IR modifications that have already taken
1021 // place: References to the original block have been replaced with the new
1022 // block.
1023 if (auto *listener =
1024 dyn_cast_or_null<RewriterBase::Listener>(Val: rewriter.getListener()))
1025 for (Operation *op : block->getUsers())
1026 listener->notifyOperationModified(op);
1027
1028 // Process the remapping for each of the original arguments.
1029 for (auto [origArg, info] :
1030 llvm::zip_equal(t: origBlock->getArguments(), u&: argInfo)) {
1031 // Handle the case of a 1->0 value mapping.
1032 if (!info) {
1033 if (Value newArg =
1034 rewriterImpl.mapping.lookupOrNull(from: origArg, desiredType: origArg.getType()))
1035 rewriter.replaceAllUsesWith(from: origArg, to: newArg);
1036 continue;
1037 }
1038
1039 // Otherwise this is a 1->1+ value mapping.
1040 Value castValue = info->castValue;
1041 assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
1042
1043 // If the argument is still used, replace it with the generated cast.
1044 if (!origArg.use_empty()) {
1045 rewriter.replaceAllUsesWith(from: origArg, to: rewriterImpl.mapping.lookupOrDefault(
1046 from: castValue, desiredType: origArg.getType()));
1047 }
1048 }
1049}
1050
1051void BlockTypeConversionRewrite::rollback() {
1052 block->replaceAllUsesWith(newValue&: origBlock);
1053}
1054
1055LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
1056 function_ref<Operation *(Value)> findLiveUser) {
1057 // Process the remapping for each of the original arguments.
1058 for (auto it : llvm::enumerate(First: origBlock->getArguments())) {
1059 BlockArgument origArg = it.value();
1060 // Note: `block` may be detached, so OpBuilder::atBlockBegin cannot be used.
1061 OpBuilder builder(it.value().getContext(), /*listener=*/&rewriterImpl);
1062 builder.setInsertionPointToStart(block);
1063
1064 // If the type of this argument changed and the argument is still live, we
1065 // need to materialize a conversion.
1066 if (rewriterImpl.mapping.lookupOrNull(from: origArg, desiredType: origArg.getType()))
1067 continue;
1068 Operation *liveUser = findLiveUser(origArg);
1069 if (!liveUser)
1070 continue;
1071
1072 Value replacementValue = rewriterImpl.mapping.lookupOrDefault(from: origArg);
1073 bool isDroppedArg = replacementValue == origArg;
1074 if (!isDroppedArg)
1075 builder.setInsertionPointAfterValue(replacementValue);
1076 Value newArg;
1077 if (converter) {
1078 newArg = converter->materializeSourceConversion(
1079 builder, loc: origArg.getLoc(), resultType: origArg.getType(),
1080 inputs: isDroppedArg ? ValueRange() : ValueRange(replacementValue));
1081 assert((!newArg || newArg.getType() == origArg.getType()) &&
1082 "materialization hook did not provide a value of the expected "
1083 "type");
1084 }
1085 if (!newArg) {
1086 InFlightDiagnostic diag =
1087 emitError(loc: origArg.getLoc())
1088 << "failed to materialize conversion for block argument #"
1089 << it.index() << " that remained live after conversion, type was "
1090 << origArg.getType();
1091 if (!isDroppedArg)
1092 diag << ", with target type " << replacementValue.getType();
1093 diag.attachNote(noteLoc: liveUser->getLoc())
1094 << "see existing live user here: " << *liveUser;
1095 return failure();
1096 }
1097 rewriterImpl.mapping.map(oldVal: origArg, newVal: newArg);
1098 }
1099 return success();
1100}
1101
1102void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
1103 Value repl = rewriterImpl.mapping.lookupOrNull(from: arg, desiredType: arg.getType());
1104 if (!repl)
1105 return;
1106
1107 if (isa<BlockArgument>(Val: repl)) {
1108 rewriter.replaceAllUsesWith(from: arg, to: repl);
1109 return;
1110 }
1111
1112 // If the replacement value is an operation, we check to make sure that we
1113 // don't replace uses that are within the parent operation of the
1114 // replacement value.
1115 Operation *replOp = cast<OpResult>(Val&: repl).getOwner();
1116 Block *replBlock = replOp->getBlock();
1117 rewriter.replaceUsesWithIf(from: arg, to: repl, functor: [&](OpOperand &operand) {
1118 Operation *user = operand.getOwner();
1119 return user->getBlock() != replBlock || replOp->isBeforeInBlock(other: user);
1120 });
1121}
1122
1123void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(value: arg); }
1124
1125void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1126 auto *listener =
1127 dyn_cast_or_null<RewriterBase::Listener>(Val: rewriter.getListener());
1128
1129 // Compute replacement values.
1130 SmallVector<Value> replacements =
1131 llvm::map_to_vector(C: op->getResults(), F: [&](OpResult result) {
1132 return rewriterImpl.mapping.lookupOrNull(from: result, desiredType: result.getType());
1133 });
1134
1135 // Notify the listener that the operation is about to be replaced.
1136 if (listener)
1137 listener->notifyOperationReplaced(op, replacement: replacements);
1138
1139 // Replace all uses with the new values.
1140 for (auto [result, newValue] :
1141 llvm::zip_equal(t: op->getResults(), u&: replacements))
1142 if (newValue)
1143 rewriter.replaceAllUsesWith(from: result, to: newValue);
1144
1145 // The original op will be erased, so remove it from the set of unlegalized
1146 // ops.
1147 if (getConfig().unlegalizedOps)
1148 getConfig().unlegalizedOps->erase(V: op);
1149
1150 // Notify the listener that the operation (and its nested operations) was
1151 // erased.
1152 if (listener) {
1153 op->walk<WalkOrder::PostOrder>(
1154 callback: [&](Operation *op) { listener->notifyOperationErased(op); });
1155 }
1156
1157 // Do not erase the operation yet. It may still be referenced in `mapping`.
1158 // Just unlink it for now and erase it during cleanup.
1159 op->getBlock()->getOperations().remove(IT: op);
1160}
1161
1162void ReplaceOperationRewrite::rollback() {
1163 for (auto result : op->getResults())
1164 rewriterImpl.mapping.erase(value: result);
1165}
1166
1167void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1168 rewriter.eraseOp(op);
1169}
1170
1171void CreateOperationRewrite::rollback() {
1172 for (Region &region : op->getRegions()) {
1173 while (!region.getBlocks().empty())
1174 region.getBlocks().remove(IT: region.getBlocks().begin());
1175 }
1176 op->dropAllUses();
1177 op->erase();
1178}
1179
1180void UnresolvedMaterializationRewrite::rollback() {
1181 if (getMaterializationKind() == MaterializationKind::Target) {
1182 for (Value input : op->getOperands())
1183 rewriterImpl.mapping.erase(value: input);
1184 }
1185 op->erase();
1186}
1187
1188void UnresolvedMaterializationRewrite::cleanup(RewriterBase &rewriter) {
1189 rewriter.eraseOp(op);
1190}
1191
1192void ConversionPatternRewriterImpl::applyRewrites() {
1193 // Commit all rewrites.
1194 IRRewriter rewriter(context, config.listener);
1195 for (auto &rewrite : rewrites)
1196 rewrite->commit(rewriter);
1197
1198 // Clean up all rewrites.
1199 SingleEraseRewriter eraseRewriter(context);
1200 for (auto &rewrite : rewrites)
1201 rewrite->cleanup(rewriter&: eraseRewriter);
1202}
1203
1204//===----------------------------------------------------------------------===//
1205// State Management
1206
1207RewriterState ConversionPatternRewriterImpl::getCurrentState() {
1208 return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
1209}
1210
1211void ConversionPatternRewriterImpl::resetState(RewriterState state) {
1212 // Undo any rewrites.
1213 undoRewrites(numRewritesToKeep: state.numRewrites);
1214
1215 // Pop all of the recorded ignored operations that are no longer valid.
1216 while (ignoredOps.size() != state.numIgnoredOperations)
1217 ignoredOps.pop_back();
1218
1219 while (replacedOps.size() != state.numReplacedOps)
1220 replacedOps.pop_back();
1221}
1222
1223void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
1224 for (auto &rewrite :
1225 llvm::reverse(C: llvm::drop_begin(RangeOrContainer&: rewrites, N: numRewritesToKeep)))
1226 rewrite->rollback();
1227 rewrites.resize(N: numRewritesToKeep);
1228}
1229
1230LogicalResult ConversionPatternRewriterImpl::remapValues(
1231 StringRef valueDiagTag, std::optional<Location> inputLoc,
1232 PatternRewriter &rewriter, ValueRange values,
1233 SmallVectorImpl<Value> &remapped) {
1234 remapped.reserve(N: llvm::size(Range&: values));
1235
1236 SmallVector<Type, 1> legalTypes;
1237 for (const auto &it : llvm::enumerate(First&: values)) {
1238 Value operand = it.value();
1239 Type origType = operand.getType();
1240
1241 // If a converter was provided, get the desired legal types for this
1242 // operand.
1243 Type desiredType;
1244 if (currentTypeConverter) {
1245 // If there is no legal conversion, fail to match this pattern.
1246 legalTypes.clear();
1247 if (failed(result: currentTypeConverter->convertType(t: origType, results&: legalTypes))) {
1248 Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1249 notifyMatchFailure(loc: operandLoc, reasonCallback: [=](Diagnostic &diag) {
1250 diag << "unable to convert type for " << valueDiagTag << " #"
1251 << it.index() << ", type was " << origType;
1252 });
1253 return failure();
1254 }
1255 // TODO: There currently isn't any mechanism to do 1->N type conversion
1256 // via the PatternRewriter replacement API, so for now we just ignore it.
1257 if (legalTypes.size() == 1)
1258 desiredType = legalTypes.front();
1259 } else {
1260 // TODO: What we should do here is just set `desiredType` to `origType`
1261 // and then handle the necessary type conversions after the conversion
1262 // process has finished. Unfortunately a lot of patterns currently rely on
1263 // receiving the new operands even if the types change, so we keep the
1264 // original behavior here for now until all of the patterns relying on
1265 // this get updated.
1266 }
1267 Value newOperand = mapping.lookupOrDefault(from: operand, desiredType);
1268
1269 // Handle the case where the conversion was 1->1 and the new operand type
1270 // isn't legal.
1271 Type newOperandType = newOperand.getType();
1272 if (currentTypeConverter && desiredType && newOperandType != desiredType) {
1273 Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1274 Value castValue = buildUnresolvedTargetMaterialization(
1275 loc: operandLoc, input: newOperand, outputType: desiredType, converter: currentTypeConverter);
1276 mapping.map(oldVal: mapping.lookupOrDefault(from: newOperand), newVal: castValue);
1277 newOperand = castValue;
1278 }
1279 remapped.push_back(Elt: newOperand);
1280 }
1281 return success();
1282}
1283
1284bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
1285 // Check to see if this operation is ignored or was replaced.
1286 return replacedOps.count(key: op) || ignoredOps.count(key: op);
1287}
1288
1289bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
1290 // Check to see if this operation was replaced.
1291 return replacedOps.count(key: op);
1292}
1293
1294//===----------------------------------------------------------------------===//
1295// Type Conversion
1296
1297FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
1298 ConversionPatternRewriter &rewriter, Block *block,
1299 const TypeConverter *converter,
1300 TypeConverter::SignatureConversion *conversion) {
1301 if (conversion)
1302 return applySignatureConversion(rewriter, block, converter, signatureConversion&: *conversion);
1303
1304 // If a converter wasn't provided, and the block wasn't already converted,
1305 // there is nothing we can do.
1306 if (!converter)
1307 return failure();
1308
1309 // Try to convert the signature for the block with the provided converter.
1310 if (auto conversion = converter->convertBlockSignature(block))
1311 return applySignatureConversion(rewriter, block, converter, signatureConversion&: *conversion);
1312 return failure();
1313}
1314
1315Block *ConversionPatternRewriterImpl::applySignatureConversion(
1316 ConversionPatternRewriter &rewriter, Region *region,
1317 TypeConverter::SignatureConversion &conversion,
1318 const TypeConverter *converter) {
1319 if (!region->empty())
1320 return *convertBlockSignature(rewriter, block: &region->front(), converter,
1321 conversion: &conversion);
1322 return nullptr;
1323}
1324
1325FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
1326 ConversionPatternRewriter &rewriter, Region *region,
1327 const TypeConverter &converter,
1328 TypeConverter::SignatureConversion *entryConversion) {
1329 regionToConverter[region] = &converter;
1330 if (region->empty())
1331 return nullptr;
1332
1333 if (failed(result: convertNonEntryRegionTypes(rewriter, region, converter)))
1334 return failure();
1335
1336 FailureOr<Block *> newEntry = convertBlockSignature(
1337 rewriter, block: &region->front(), converter: &converter, conversion: entryConversion);
1338 return newEntry;
1339}
1340
1341LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
1342 ConversionPatternRewriter &rewriter, Region *region,
1343 const TypeConverter &converter,
1344 ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
1345 regionToConverter[region] = &converter;
1346 if (region->empty())
1347 return success();
1348
1349 // Convert the arguments of each block within the region.
1350 int blockIdx = 0;
1351 assert((blockConversions.empty() ||
1352 blockConversions.size() == region->getBlocks().size() - 1) &&
1353 "expected either to provide no SignatureConversions at all or to "
1354 "provide a SignatureConversion for each non-entry block");
1355
1356 for (Block &block :
1357 llvm::make_early_inc_range(Range: llvm::drop_begin(RangeOrContainer&: *region, N: 1))) {
1358 TypeConverter::SignatureConversion *blockConversion =
1359 blockConversions.empty()
1360 ? nullptr
1361 : const_cast<TypeConverter::SignatureConversion *>(
1362 &blockConversions[blockIdx++]);
1363
1364 if (failed(result: convertBlockSignature(rewriter, block: &block, converter: &converter,
1365 conversion: blockConversion)))
1366 return failure();
1367 }
1368 return success();
1369}
1370
1371Block *ConversionPatternRewriterImpl::applySignatureConversion(
1372 ConversionPatternRewriter &rewriter, Block *block,
1373 const TypeConverter *converter,
1374 TypeConverter::SignatureConversion &signatureConversion) {
1375 OpBuilder::InsertionGuard g(rewriter);
1376
1377 // If no arguments are being changed or added, there is nothing to do.
1378 unsigned origArgCount = block->getNumArguments();
1379 auto convertedTypes = signatureConversion.getConvertedTypes();
1380 if (llvm::equal(LRange: block->getArgumentTypes(), RRange&: convertedTypes))
1381 return block;
1382
1383 // Compute the locations of all block arguments in the new block.
1384 SmallVector<Location> newLocs(convertedTypes.size(),
1385 rewriter.getUnknownLoc());
1386 for (unsigned i = 0; i < origArgCount; ++i) {
1387 auto inputMap = signatureConversion.getInputMapping(input: i);
1388 if (!inputMap || inputMap->replacementValue)
1389 continue;
1390 Location origLoc = block->getArgument(i).getLoc();
1391 for (unsigned j = 0; j < inputMap->size; ++j)
1392 newLocs[inputMap->inputNo + j] = origLoc;
1393 }
1394
1395 // Insert a new block with the converted block argument types and move all ops
1396 // from the old block to the new block.
1397 Block *newBlock =
1398 rewriter.createBlock(parent: block->getParent(), insertPt: std::next(x: block->getIterator()),
1399 argTypes: convertedTypes, locs: newLocs);
1400
1401 // If a listener is attached to the dialect conversion, ops cannot be moved
1402 // to the destination block in bulk ("fast path"). This is because at the time
1403 // the notifications are sent, it is unknown which ops were moved. Instead,
1404 // ops should be moved one-by-one ("slow path"), so that a separate
1405 // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1406 // a bit more efficient, so we try to do that when possible.
1407 bool fastPath = !config.listener;
1408 if (fastPath) {
1409 appendRewrite<InlineBlockRewrite>(args&: newBlock, args&: block, args: newBlock->end());
1410 newBlock->getOperations().splice(where: newBlock->end(), L2&: block->getOperations());
1411 } else {
1412 while (!block->empty())
1413 rewriter.moveOpBefore(op: &block->front(), block: newBlock, iterator: newBlock->end());
1414 }
1415
1416 // Replace all uses of the old block with the new block.
1417 block->replaceAllUsesWith(newValue&: newBlock);
1418
1419 // Remap each of the original arguments as determined by the signature
1420 // conversion.
1421 SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
1422 argInfo.resize(N: origArgCount);
1423
1424 for (unsigned i = 0; i != origArgCount; ++i) {
1425 auto inputMap = signatureConversion.getInputMapping(input: i);
1426 if (!inputMap)
1427 continue;
1428 BlockArgument origArg = block->getArgument(i);
1429
1430 // If inputMap->replacementValue is not nullptr, then the argument is
1431 // dropped and a replacement value is provided to be the remappedValue.
1432 if (inputMap->replacementValue) {
1433 assert(inputMap->size == 0 &&
1434 "invalid to provide a replacement value when the argument isn't "
1435 "dropped");
1436 mapping.map(oldVal: origArg, newVal: inputMap->replacementValue);
1437 appendRewrite<ReplaceBlockArgRewrite>(args&: block, args&: origArg);
1438 continue;
1439 }
1440
1441 // Otherwise, this is a 1->1+ mapping.
1442 auto replArgs =
1443 newBlock->getArguments().slice(N: inputMap->inputNo, M: inputMap->size);
1444 Value newArg;
1445
1446 // If this is a 1->1 mapping and the types of new and replacement arguments
1447 // match (i.e. it's an identity map), then the argument is mapped to its
1448 // original type.
1449 // FIXME: We simply pass through the replacement argument if there wasn't a
1450 // converter, which isn't great as it allows implicit type conversions to
1451 // appear. We should properly restructure this code to handle cases where a
1452 // converter isn't provided and also to properly handle the case where an
1453 // argument materialization is actually a temporary source materialization
1454 // (e.g. in the case of 1->N).
1455 if (replArgs.size() == 1 &&
1456 (!converter || replArgs[0].getType() == origArg.getType())) {
1457 newArg = replArgs.front();
1458 } else {
1459 Type origOutputType = origArg.getType();
1460
1461 // Legalize the argument output type.
1462 Type outputType = origOutputType;
1463 if (Type legalOutputType = converter->convertType(t: outputType))
1464 outputType = legalOutputType;
1465
1466 newArg = buildUnresolvedArgumentMaterialization(
1467 block: newBlock, loc: origArg.getLoc(), inputs: replArgs, origOutputType, outputType,
1468 converter);
1469 }
1470
1471 mapping.map(oldVal: origArg, newVal: newArg);
1472 appendRewrite<ReplaceBlockArgRewrite>(args&: block, args&: origArg);
1473 argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
1474 }
1475
1476 appendRewrite<BlockTypeConversionRewrite>(args&: newBlock, args&: block, args&: argInfo,
1477 args&: converter);
1478
1479 // Erase the old block. (It is just unlinked for now and will be erased during
1480 // cleanup.)
1481 rewriter.eraseBlock(block);
1482
1483 return newBlock;
1484}
1485
1486//===----------------------------------------------------------------------===//
1487// Materializations
1488//===----------------------------------------------------------------------===//
1489
1490/// Build an unresolved materialization operation given an output type and set
1491/// of input operands.
1492Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
1493 MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
1494 Location loc, ValueRange inputs, Type outputType, Type origOutputType,
1495 const TypeConverter *converter) {
1496 // Avoid materializing an unnecessary cast.
1497 if (inputs.size() == 1 && inputs.front().getType() == outputType)
1498 return inputs.front();
1499
1500 // Create an unresolved materialization. We use a new OpBuilder to avoid
1501 // tracking the materialization like we do for other operations.
1502 OpBuilder builder(insertBlock, insertPt);
1503 auto convertOp =
1504 builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
1505 appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
1506 origOutputType);
1507 return convertOp.getResult(0);
1508}
1509Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
1510 Block *block, Location loc, ValueRange inputs, Type origOutputType,
1511 Type outputType, const TypeConverter *converter) {
1512 return buildUnresolvedMaterialization(kind: MaterializationKind::Argument, insertBlock: block,
1513 insertPt: block->begin(), loc, inputs, outputType,
1514 origOutputType, converter);
1515}
1516Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
1517 Location loc, Value input, Type outputType,
1518 const TypeConverter *converter) {
1519 Block *insertBlock = input.getParentBlock();
1520 Block::iterator insertPt = insertBlock->begin();
1521 if (OpResult inputRes = dyn_cast<OpResult>(Val&: input))
1522 insertPt = ++inputRes.getOwner()->getIterator();
1523
1524 return buildUnresolvedMaterialization(kind: MaterializationKind::Target,
1525 insertBlock, insertPt, loc, inputs: input,
1526 outputType, origOutputType: outputType, converter);
1527}
1528
1529//===----------------------------------------------------------------------===//
1530// Rewriter Notification Hooks
1531
1532void ConversionPatternRewriterImpl::notifyOperationInserted(
1533 Operation *op, OpBuilder::InsertPoint previous) {
1534 LLVM_DEBUG({
1535 logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
1536 << ")\n";
1537 });
1538 assert(!wasOpReplaced(op->getParentOp()) &&
1539 "attempting to insert into a block within a replaced/erased op");
1540
1541 if (!previous.isSet()) {
1542 // This is a newly created op.
1543 appendRewrite<CreateOperationRewrite>(args&: op);
1544 return;
1545 }
1546 Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
1547 ? nullptr
1548 : &*previous.getPoint();
1549 appendRewrite<MoveOperationRewrite>(args&: op, args: previous.getBlock(), args&: prevOp);
1550}
1551
1552void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
1553 ValueRange newValues) {
1554 assert(newValues.size() == op->getNumResults());
1555 assert(!ignoredOps.contains(op) && "operation was already replaced");
1556
1557 // Track if any of the results changed, e.g. erased and replaced with null.
1558 bool resultChanged = false;
1559
1560 // Create mappings for each of the new result values.
1561 for (auto [newValue, result] : llvm::zip(t&: newValues, u: op->getResults())) {
1562 if (!newValue) {
1563 resultChanged = true;
1564 continue;
1565 }
1566 // Remap, and check for any result type changes.
1567 mapping.map(oldVal: result, newVal: newValue);
1568 resultChanged |= (newValue.getType() != result.getType());
1569 }
1570
1571 appendRewrite<ReplaceOperationRewrite>(args&: op, args&: currentTypeConverter,
1572 args&: resultChanged);
1573
1574 // Mark this operation and all nested ops as replaced.
1575 op->walk(callback: [&](Operation *op) { replacedOps.insert(X: op); });
1576}
1577
1578void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
1579 Region *region = block->getParent();
1580 Block *origNextBlock = block->getNextNode();
1581 appendRewrite<EraseBlockRewrite>(args&: block, args&: region, args&: origNextBlock);
1582}
1583
1584void ConversionPatternRewriterImpl::notifyBlockInserted(
1585 Block *block, Region *previous, Region::iterator previousIt) {
1586 assert(!wasOpReplaced(block->getParentOp()) &&
1587 "attempting to insert into a region within a replaced/erased op");
1588 LLVM_DEBUG(
1589 {
1590 Operation *parent = block->getParentOp();
1591 if (parent) {
1592 logger.startLine() << "** Insert Block into : '" << parent->getName()
1593 << "'(" << parent << ")\n";
1594 } else {
1595 logger.startLine()
1596 << "** Insert Block into detached Region (nullptr parent op)'";
1597 }
1598 });
1599
1600 if (!previous) {
1601 // This is a newly created block.
1602 appendRewrite<CreateBlockRewrite>(args&: block);
1603 return;
1604 }
1605 Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
1606 appendRewrite<MoveBlockRewrite>(args&: block, args&: previous, args&: prevBlock);
1607}
1608
1609void ConversionPatternRewriterImpl::notifyBlockBeingInlined(
1610 Block *block, Block *srcBlock, Block::iterator before) {
1611 appendRewrite<InlineBlockRewrite>(args&: block, args&: srcBlock, args&: before);
1612}
1613
1614void ConversionPatternRewriterImpl::notifyMatchFailure(
1615 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1616 LLVM_DEBUG({
1617 Diagnostic diag(loc, DiagnosticSeverity::Remark);
1618 reasonCallback(diag);
1619 logger.startLine() << "** Failure : " << diag.str() << "\n";
1620 if (config.notifyCallback)
1621 config.notifyCallback(diag);
1622 });
1623}
1624
1625//===----------------------------------------------------------------------===//
1626// ConversionPatternRewriter
1627//===----------------------------------------------------------------------===//
1628
1629ConversionPatternRewriter::ConversionPatternRewriter(
1630 MLIRContext *ctx, const ConversionConfig &config)
1631 : PatternRewriter(ctx),
1632 impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
1633 setListener(impl.get());
1634}
1635
1636ConversionPatternRewriter::~ConversionPatternRewriter() = default;
1637
1638void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
1639 assert(op && newOp && "expected non-null op");
1640 replaceOp(op, newValues: newOp->getResults());
1641}
1642
1643void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
1644 assert(op->getNumResults() == newValues.size() &&
1645 "incorrect # of replacement values");
1646 LLVM_DEBUG({
1647 impl->logger.startLine()
1648 << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1649 });
1650 impl->notifyOpReplaced(op, newValues);
1651}
1652
1653void ConversionPatternRewriter::eraseOp(Operation *op) {
1654 LLVM_DEBUG({
1655 impl->logger.startLine()
1656 << "** Erase : '" << op->getName() << "'(" << op << ")\n";
1657 });
1658 SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr);
1659 impl->notifyOpReplaced(op, newValues: nullRepls);
1660}
1661
1662void ConversionPatternRewriter::eraseBlock(Block *block) {
1663 assert(!impl->wasOpReplaced(block->getParentOp()) &&
1664 "attempting to erase a block within a replaced/erased op");
1665
1666 // Mark all ops for erasure.
1667 for (Operation &op : *block)
1668 eraseOp(op: &op);
1669
1670 // Unlink the block from its parent region. The block is kept in the rewrite
1671 // object and will be actually destroyed when rewrites are applied. This
1672 // allows us to keep the operations in the block live and undo the removal by
1673 // re-inserting the block.
1674 impl->notifyBlockIsBeingErased(block);
1675 block->getParent()->getBlocks().remove(IT: block);
1676}
1677
1678Block *ConversionPatternRewriter::applySignatureConversion(
1679 Region *region, TypeConverter::SignatureConversion &conversion,
1680 const TypeConverter *converter) {
1681 assert(!impl->wasOpReplaced(region->getParentOp()) &&
1682 "attempting to apply a signature conversion to a block within a "
1683 "replaced/erased op");
1684 return impl->applySignatureConversion(rewriter&: *this, region, conversion, converter);
1685}
1686
1687FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
1688 Region *region, const TypeConverter &converter,
1689 TypeConverter::SignatureConversion *entryConversion) {
1690 assert(!impl->wasOpReplaced(region->getParentOp()) &&
1691 "attempting to apply a signature conversion to a block within a "
1692 "replaced/erased op");
1693 return impl->convertRegionTypes(rewriter&: *this, region, converter, entryConversion);
1694}
1695
1696LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
1697 Region *region, const TypeConverter &converter,
1698 ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
1699 assert(!impl->wasOpReplaced(region->getParentOp()) &&
1700 "attempting to apply a signature conversion to a block within a "
1701 "replaced/erased op");
1702 return impl->convertNonEntryRegionTypes(rewriter&: *this, region, converter,
1703 blockConversions);
1704}
1705
1706void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
1707 Value to) {
1708 LLVM_DEBUG({
1709 Operation *parentOp = from.getOwner()->getParentOp();
1710 impl->logger.startLine() << "** Replace Argument : '" << from
1711 << "'(in region of '" << parentOp->getName()
1712 << "'(" << from.getOwner()->getParentOp() << ")\n";
1713 });
1714 impl->appendRewrite<ReplaceBlockArgRewrite>(args: from.getOwner(), args&: from);
1715 impl->mapping.map(oldVal: impl->mapping.lookupOrDefault(from), newVal: to);
1716}
1717
1718Value ConversionPatternRewriter::getRemappedValue(Value key) {
1719 SmallVector<Value> remappedValues;
1720 if (failed(result: impl->remapValues(valueDiagTag: "value", /*inputLoc=*/std::nullopt, rewriter&: *this, values: key,
1721 remapped&: remappedValues)))
1722 return nullptr;
1723 return remappedValues.front();
1724}
1725
1726LogicalResult
1727ConversionPatternRewriter::getRemappedValues(ValueRange keys,
1728 SmallVectorImpl<Value> &results) {
1729 if (keys.empty())
1730 return success();
1731 return impl->remapValues(valueDiagTag: "value", /*inputLoc=*/std::nullopt, rewriter&: *this, values: keys,
1732 remapped&: results);
1733}
1734
1735void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
1736 Block::iterator before,
1737 ValueRange argValues) {
1738#ifndef NDEBUG
1739 assert(argValues.size() == source->getNumArguments() &&
1740 "incorrect # of argument replacement values");
1741 assert(!impl->wasOpReplaced(source->getParentOp()) &&
1742 "attempting to inline a block from a replaced/erased op");
1743 assert(!impl->wasOpReplaced(dest->getParentOp()) &&
1744 "attempting to inline a block into a replaced/erased op");
1745 auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
1746 // The source block will be deleted, so it should not have any users (i.e.,
1747 // there should be no predecessors).
1748 assert(llvm::all_of(source->getUsers(), opIgnored) &&
1749 "expected 'source' to have no predecessors");
1750#endif // NDEBUG
1751
1752 // If a listener is attached to the dialect conversion, ops cannot be moved
1753 // to the destination block in bulk ("fast path"). This is because at the time
1754 // the notifications are sent, it is unknown which ops were moved. Instead,
1755 // ops should be moved one-by-one ("slow path"), so that a separate
1756 // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1757 // a bit more efficient, so we try to do that when possible.
1758 bool fastPath = !impl->config.listener;
1759
1760 if (fastPath)
1761 impl->notifyBlockBeingInlined(block: dest, srcBlock: source, before);
1762
1763 // Replace all uses of block arguments.
1764 for (auto it : llvm::zip(t: source->getArguments(), u&: argValues))
1765 replaceUsesOfBlockArgument(from: std::get<0>(t&: it), to: std::get<1>(t&: it));
1766
1767 if (fastPath) {
1768 // Move all ops at once.
1769 dest->getOperations().splice(where: before, L2&: source->getOperations());
1770 } else {
1771 // Move op by op.
1772 while (!source->empty())
1773 moveOpBefore(op: &source->front(), block: dest, iterator: before);
1774 }
1775
1776 // Erase the source block.
1777 eraseBlock(block: source);
1778}
1779
1780void ConversionPatternRewriter::startOpModification(Operation *op) {
1781 assert(!impl->wasOpReplaced(op) &&
1782 "attempting to modify a replaced/erased op");
1783#ifndef NDEBUG
1784 impl->pendingRootUpdates.insert(Ptr: op);
1785#endif
1786 impl->appendRewrite<ModifyOperationRewrite>(args&: op);
1787}
1788
1789void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
1790 assert(!impl->wasOpReplaced(op) &&
1791 "attempting to modify a replaced/erased op");
1792 PatternRewriter::finalizeOpModification(op);
1793 // There is nothing to do here, we only need to track the operation at the
1794 // start of the update.
1795#ifndef NDEBUG
1796 assert(impl->pendingRootUpdates.erase(op) &&
1797 "operation did not have a pending in-place update");
1798#endif
1799}
1800
1801void ConversionPatternRewriter::cancelOpModification(Operation *op) {
1802#ifndef NDEBUG
1803 assert(impl->pendingRootUpdates.erase(op) &&
1804 "operation did not have a pending in-place update");
1805#endif
1806 // Erase the last update for this operation.
1807 auto it = llvm::find_if(
1808 Range: llvm::reverse(C&: impl->rewrites), P: [&](std::unique_ptr<IRRewrite> &rewrite) {
1809 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(Val: rewrite.get());
1810 return modifyRewrite && modifyRewrite->getOperation() == op;
1811 });
1812 assert(it != impl->rewrites.rend() && "no root update started on op");
1813 (*it)->rollback();
1814 int updateIdx = std::prev(x: impl->rewrites.rend()) - it;
1815 impl->rewrites.erase(CI: impl->rewrites.begin() + updateIdx);
1816}
1817
1818detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
1819 return *impl;
1820}
1821
1822//===----------------------------------------------------------------------===//
1823// ConversionPattern
1824//===----------------------------------------------------------------------===//
1825
1826LogicalResult
1827ConversionPattern::matchAndRewrite(Operation *op,
1828 PatternRewriter &rewriter) const {
1829 auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1830 auto &rewriterImpl = dialectRewriter.getImpl();
1831
1832 // Track the current conversion pattern type converter in the rewriter.
1833 llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
1834 getTypeConverter());
1835
1836 // Remap the operands of the operation.
1837 SmallVector<Value, 4> operands;
1838 if (failed(result: rewriterImpl.remapValues(valueDiagTag: "operand", inputLoc: op->getLoc(), rewriter,
1839 values: op->getOperands(), remapped&: operands))) {
1840 return failure();
1841 }
1842 return matchAndRewrite(op, operands, rewriter&: dialectRewriter);
1843}
1844
1845//===----------------------------------------------------------------------===//
1846// OperationLegalizer
1847//===----------------------------------------------------------------------===//
1848
1849namespace {
1850/// A set of rewrite patterns that can be used to legalize a given operation.
1851using LegalizationPatterns = SmallVector<const Pattern *, 1>;
1852
1853/// This class defines a recursive operation legalizer.
1854class OperationLegalizer {
1855public:
1856 using LegalizationAction = ConversionTarget::LegalizationAction;
1857
1858 OperationLegalizer(const ConversionTarget &targetInfo,
1859 const FrozenRewritePatternSet &patterns,
1860 const ConversionConfig &config);
1861
1862 /// Returns true if the given operation is known to be illegal on the target.
1863 bool isIllegal(Operation *op) const;
1864
1865 /// Attempt to legalize the given operation. Returns success if the operation
1866 /// was legalized, failure otherwise.
1867 LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
1868
1869 /// Returns the conversion target in use by the legalizer.
1870 const ConversionTarget &getTarget() { return target; }
1871
1872private:
1873 /// Attempt to legalize the given operation by folding it.
1874 LogicalResult legalizeWithFold(Operation *op,
1875 ConversionPatternRewriter &rewriter);
1876
1877 /// Attempt to legalize the given operation by applying a pattern. Returns
1878 /// success if the operation was legalized, failure otherwise.
1879 LogicalResult legalizeWithPattern(Operation *op,
1880 ConversionPatternRewriter &rewriter);
1881
1882 /// Return true if the given pattern may be applied to the given operation,
1883 /// false otherwise.
1884 bool canApplyPattern(Operation *op, const Pattern &pattern,
1885 ConversionPatternRewriter &rewriter);
1886
1887 /// Legalize the resultant IR after successfully applying the given pattern.
1888 LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
1889 ConversionPatternRewriter &rewriter,
1890 RewriterState &curState);
1891
1892 /// Legalizes the actions registered during the execution of a pattern.
1893 LogicalResult
1894 legalizePatternBlockRewrites(Operation *op,
1895 ConversionPatternRewriter &rewriter,
1896 ConversionPatternRewriterImpl &impl,
1897 RewriterState &state, RewriterState &newState);
1898 LogicalResult legalizePatternCreatedOperations(
1899 ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
1900 RewriterState &state, RewriterState &newState);
1901 LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
1902 ConversionPatternRewriterImpl &impl,
1903 RewriterState &state,
1904 RewriterState &newState);
1905
1906 //===--------------------------------------------------------------------===//
1907 // Cost Model
1908 //===--------------------------------------------------------------------===//
1909
1910 /// Build an optimistic legalization graph given the provided patterns. This
1911 /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
1912 /// patterns for operations that are not directly legal, but may be
1913 /// transitively legal for the current target given the provided patterns.
1914 void buildLegalizationGraph(
1915 LegalizationPatterns &anyOpLegalizerPatterns,
1916 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
1917
1918 /// Compute the benefit of each node within the computed legalization graph.
1919 /// This orders the patterns within 'legalizerPatterns' based upon two
1920 /// criteria:
1921 /// 1) Prefer patterns that have the lowest legalization depth, i.e.
1922 /// represent the more direct mapping to the target.
1923 /// 2) When comparing patterns with the same legalization depth, prefer the
1924 /// pattern with the highest PatternBenefit. This allows for users to
1925 /// prefer specific legalizations over others.
1926 void computeLegalizationGraphBenefit(
1927 LegalizationPatterns &anyOpLegalizerPatterns,
1928 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
1929
1930 /// Compute the legalization depth when legalizing an operation of the given
1931 /// type.
1932 unsigned computeOpLegalizationDepth(
1933 OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
1934 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
1935
1936 /// Apply the conversion cost model to the given set of patterns, and return
1937 /// the smallest legalization depth of any of the patterns. See
1938 /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
1939 unsigned applyCostModelToPatterns(
1940 LegalizationPatterns &patterns,
1941 DenseMap<OperationName, unsigned> &minOpPatternDepth,
1942 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
1943
1944 /// The current set of patterns that have been applied.
1945 SmallPtrSet<const Pattern *, 8> appliedPatterns;
1946
1947 /// The legalization information provided by the target.
1948 const ConversionTarget &target;
1949
1950 /// The pattern applicator to use for conversions.
1951 PatternApplicator applicator;
1952
1953 /// Dialect conversion configuration.
1954 const ConversionConfig &config;
1955};
1956} // namespace
1957
1958OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
1959 const FrozenRewritePatternSet &patterns,
1960 const ConversionConfig &config)
1961 : target(targetInfo), applicator(patterns), config(config) {
1962 // The set of patterns that can be applied to illegal operations to transform
1963 // them into legal ones.
1964 DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
1965 LegalizationPatterns anyOpLegalizerPatterns;
1966
1967 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1968 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1969}
1970
1971bool OperationLegalizer::isIllegal(Operation *op) const {
1972 return target.isIllegal(op);
1973}
1974
1975LogicalResult
1976OperationLegalizer::legalize(Operation *op,
1977 ConversionPatternRewriter &rewriter) {
1978#ifndef NDEBUG
1979 const char *logLineComment =
1980 "//===-------------------------------------------===//\n";
1981
1982 auto &logger = rewriter.getImpl().logger;
1983#endif
1984 LLVM_DEBUG({
1985 logger.getOStream() << "\n";
1986 logger.startLine() << logLineComment;
1987 logger.startLine() << "Legalizing operation : '" << op->getName() << "'("
1988 << op << ") {\n";
1989 logger.indent();
1990
1991 // If the operation has no regions, just print it here.
1992 if (op->getNumRegions() == 0) {
1993 op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
1994 logger.getOStream() << "\n\n";
1995 }
1996 });
1997
1998 // Check if this operation is legal on the target.
1999 if (auto legalityInfo = target.isLegal(op)) {
2000 LLVM_DEBUG({
2001 logSuccess(
2002 logger, "operation marked legal by the target{0}",
2003 legalityInfo->isRecursivelyLegal
2004 ? "; NOTE: operation is recursively legal; skipping internals"
2005 : "");
2006 logger.startLine() << logLineComment;
2007 });
2008
2009 // If this operation is recursively legal, mark its children as ignored so
2010 // that we don't consider them for legalization.
2011 if (legalityInfo->isRecursivelyLegal) {
2012 op->walk(callback: [&](Operation *nested) {
2013 if (op != nested)
2014 rewriter.getImpl().ignoredOps.insert(X: nested);
2015 });
2016 }
2017
2018 return success();
2019 }
2020
2021 // Check to see if the operation is ignored and doesn't need to be converted.
2022 if (rewriter.getImpl().isOpIgnored(op)) {
2023 LLVM_DEBUG({
2024 logSuccess(logger, "operation marked 'ignored' during conversion");
2025 logger.startLine() << logLineComment;
2026 });
2027 return success();
2028 }
2029
2030 // If the operation isn't legal, try to fold it in-place.
2031 // TODO: Should we always try to do this, even if the op is
2032 // already legal?
2033 if (succeeded(result: legalizeWithFold(op, rewriter))) {
2034 LLVM_DEBUG({
2035 logSuccess(logger, "operation was folded");
2036 logger.startLine() << logLineComment;
2037 });
2038 return success();
2039 }
2040
2041 // Otherwise, we need to apply a legalization pattern to this operation.
2042 if (succeeded(result: legalizeWithPattern(op, rewriter))) {
2043 LLVM_DEBUG({
2044 logSuccess(logger, "");
2045 logger.startLine() << logLineComment;
2046 });
2047 return success();
2048 }
2049
2050 LLVM_DEBUG({
2051 logFailure(logger, "no matched legalization pattern");
2052 logger.startLine() << logLineComment;
2053 });
2054 return failure();
2055}
2056
2057LogicalResult
2058OperationLegalizer::legalizeWithFold(Operation *op,
2059 ConversionPatternRewriter &rewriter) {
2060 auto &rewriterImpl = rewriter.getImpl();
2061 RewriterState curState = rewriterImpl.getCurrentState();
2062
2063 LLVM_DEBUG({
2064 rewriterImpl.logger.startLine() << "* Fold {\n";
2065 rewriterImpl.logger.indent();
2066 });
2067
2068 // Try to fold the operation.
2069 SmallVector<Value, 2> replacementValues;
2070 rewriter.setInsertionPoint(op);
2071 if (failed(result: rewriter.tryFold(op, results&: replacementValues))) {
2072 LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
2073 return failure();
2074 }
2075 // An empty list of replacement values indicates that the fold was in-place.
2076 // As the operation changed, a new legalization needs to be attempted.
2077 if (replacementValues.empty())
2078 return legalize(op, rewriter);
2079
2080 // Insert a replacement for 'op' with the folded replacement values.
2081 rewriter.replaceOp(op, newValues: replacementValues);
2082
2083 // Recursively legalize any new constant operations.
2084 for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
2085 i != e; ++i) {
2086 auto *createOp =
2087 dyn_cast<CreateOperationRewrite>(Val: rewriterImpl.rewrites[i].get());
2088 if (!createOp)
2089 continue;
2090 if (failed(result: legalize(op: createOp->getOperation(), rewriter))) {
2091 LLVM_DEBUG(logFailure(rewriterImpl.logger,
2092 "failed to legalize generated constant '{0}'",
2093 createOp->getOperation()->getName()));
2094 rewriterImpl.resetState(state: curState);
2095 return failure();
2096 }
2097 }
2098
2099 LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
2100 return success();
2101}
2102
2103LogicalResult
2104OperationLegalizer::legalizeWithPattern(Operation *op,
2105 ConversionPatternRewriter &rewriter) {
2106 auto &rewriterImpl = rewriter.getImpl();
2107
2108 // Functor that returns if the given pattern may be applied.
2109 auto canApply = [&](const Pattern &pattern) {
2110 bool canApply = canApplyPattern(op, pattern, rewriter);
2111 if (canApply && config.listener)
2112 config.listener->notifyPatternBegin(pattern, op);
2113 return canApply;
2114 };
2115
2116 // Functor that cleans up the rewriter state after a pattern failed to match.
2117 RewriterState curState = rewriterImpl.getCurrentState();
2118 auto onFailure = [&](const Pattern &pattern) {
2119 assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2120 LLVM_DEBUG({
2121 logFailure(rewriterImpl.logger, "pattern failed to match");
2122 if (rewriterImpl.config.notifyCallback) {
2123 Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
2124 diag << "Failed to apply pattern \"" << pattern.getDebugName()
2125 << "\" on op:\n"
2126 << *op;
2127 rewriterImpl.config.notifyCallback(diag);
2128 }
2129 });
2130 if (config.listener)
2131 config.listener->notifyPatternEnd(pattern, status: failure());
2132 rewriterImpl.resetState(state: curState);
2133 appliedPatterns.erase(Ptr: &pattern);
2134 };
2135
2136 // Functor that performs additional legalization when a pattern is
2137 // successfully applied.
2138 auto onSuccess = [&](const Pattern &pattern) {
2139 assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2140 auto result = legalizePatternResult(op, pattern, rewriter, curState);
2141 appliedPatterns.erase(Ptr: &pattern);
2142 if (failed(result))
2143 rewriterImpl.resetState(state: curState);
2144 if (config.listener)
2145 config.listener->notifyPatternEnd(pattern, status: result);
2146 return result;
2147 };
2148
2149 // Try to match and rewrite a pattern on this operation.
2150 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2151 onSuccess);
2152}
2153
2154bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
2155 ConversionPatternRewriter &rewriter) {
2156 LLVM_DEBUG({
2157 auto &os = rewriter.getImpl().logger;
2158 os.getOStream() << "\n";
2159 os.startLine() << "* Pattern : '" << op->getName() << " -> (";
2160 llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
2161 os.getOStream() << ")' {\n";
2162 os.indent();
2163 });
2164
2165 // Ensure that we don't cycle by not allowing the same pattern to be
2166 // applied twice in the same recursion stack if it is not known to be safe.
2167 if (!pattern.hasBoundedRewriteRecursion() &&
2168 !appliedPatterns.insert(Ptr: &pattern).second) {
2169 LLVM_DEBUG(
2170 logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2171 return false;
2172 }
2173 return true;
2174}
2175
2176LogicalResult
2177OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
2178 ConversionPatternRewriter &rewriter,
2179 RewriterState &curState) {
2180 auto &impl = rewriter.getImpl();
2181
2182#ifndef NDEBUG
2183 assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2184 // Check that the root was either replaced or updated in place.
2185 auto newRewrites = llvm::drop_begin(RangeOrContainer&: impl.rewrites, N: curState.numRewrites);
2186 auto replacedRoot = [&] {
2187 return hasRewrite<ReplaceOperationRewrite>(rewrites&: newRewrites, op);
2188 };
2189 auto updatedRootInPlace = [&] {
2190 return hasRewrite<ModifyOperationRewrite>(rewrites&: newRewrites, op);
2191 };
2192 assert((replacedRoot() || updatedRootInPlace()) &&
2193 "expected pattern to replace the root operation");
2194#endif // NDEBUG
2195
2196 // Legalize each of the actions registered during application.
2197 RewriterState newState = impl.getCurrentState();
2198 if (failed(result: legalizePatternBlockRewrites(op, rewriter, impl, state&: curState,
2199 newState)) ||
2200 failed(result: legalizePatternRootUpdates(rewriter, impl, state&: curState, newState)) ||
2201 failed(result: legalizePatternCreatedOperations(rewriter, impl, state&: curState,
2202 newState))) {
2203 return failure();
2204 }
2205
2206 LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2207 return success();
2208}
2209
2210LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2211 Operation *op, ConversionPatternRewriter &rewriter,
2212 ConversionPatternRewriterImpl &impl, RewriterState &state,
2213 RewriterState &newState) {
2214 SmallPtrSet<Operation *, 16> operationsToIgnore;
2215
2216 // If the pattern moved or created any blocks, make sure the types of block
2217 // arguments get legalized.
2218 for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2219 BlockRewrite *rewrite = dyn_cast<BlockRewrite>(Val: impl.rewrites[i].get());
2220 if (!rewrite)
2221 continue;
2222 Block *block = rewrite->getBlock();
2223 if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
2224 ReplaceBlockArgRewrite>(Val: rewrite))
2225 continue;
2226 // Only check blocks outside of the current operation.
2227 Operation *parentOp = block->getParentOp();
2228 if (!parentOp || parentOp == op || block->getNumArguments() == 0)
2229 continue;
2230
2231 // If the region of the block has a type converter, try to convert the block
2232 // directly.
2233 if (auto *converter = impl.regionToConverter.lookup(Val: block->getParent())) {
2234 if (failed(result: impl.convertBlockSignature(rewriter, block, converter))) {
2235 LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
2236 "block"));
2237 return failure();
2238 }
2239 continue;
2240 }
2241
2242 // Otherwise, check that this operation isn't one generated by this pattern.
2243 // This is because we will attempt to legalize the parent operation, and
2244 // blocks in regions created by this pattern will already be legalized later
2245 // on. If we haven't built the set yet, build it now.
2246 if (operationsToIgnore.empty()) {
2247 for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
2248 ++i) {
2249 auto *createOp =
2250 dyn_cast<CreateOperationRewrite>(Val: impl.rewrites[i].get());
2251 if (!createOp)
2252 continue;
2253 operationsToIgnore.insert(Ptr: createOp->getOperation());
2254 }
2255 }
2256
2257 // If this operation should be considered for re-legalization, try it.
2258 if (operationsToIgnore.insert(Ptr: parentOp).second &&
2259 failed(result: legalize(op: parentOp, rewriter))) {
2260 LLVM_DEBUG(logFailure(impl.logger,
2261 "operation '{0}'({1}) became illegal after rewrite",
2262 parentOp->getName(), parentOp));
2263 return failure();
2264 }
2265 }
2266 return success();
2267}
2268
2269LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2270 ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
2271 RewriterState &state, RewriterState &newState) {
2272 for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2273 auto *createOp = dyn_cast<CreateOperationRewrite>(Val: impl.rewrites[i].get());
2274 if (!createOp)
2275 continue;
2276 Operation *op = createOp->getOperation();
2277 if (failed(result: legalize(op, rewriter))) {
2278 LLVM_DEBUG(logFailure(impl.logger,
2279 "failed to legalize generated operation '{0}'({1})",
2280 op->getName(), op));
2281 return failure();
2282 }
2283 }
2284 return success();
2285}
2286
2287LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2288 ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
2289 RewriterState &state, RewriterState &newState) {
2290 for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2291 auto *rewrite = dyn_cast<ModifyOperationRewrite>(Val: impl.rewrites[i].get());
2292 if (!rewrite)
2293 continue;
2294 Operation *op = rewrite->getOperation();
2295 if (failed(result: legalize(op, rewriter))) {
2296 LLVM_DEBUG(logFailure(
2297 impl.logger, "failed to legalize operation updated in-place '{0}'",
2298 op->getName()));
2299 return failure();
2300 }
2301 }
2302 return success();
2303}
2304
2305//===----------------------------------------------------------------------===//
2306// Cost Model
2307
2308void OperationLegalizer::buildLegalizationGraph(
2309 LegalizationPatterns &anyOpLegalizerPatterns,
2310 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2311 // A mapping between an operation and a set of operations that can be used to
2312 // generate it.
2313 DenseMap<OperationName, SmallPtrSet<OperationName, 2>> parentOps;
2314 // A mapping between an operation and any currently invalid patterns it has.
2315 DenseMap<OperationName, SmallPtrSet<const Pattern *, 2>> invalidPatterns;
2316 // A worklist of patterns to consider for legality.
2317 SetVector<const Pattern *> patternWorklist;
2318
2319 // Build the mapping from operations to the parent ops that may generate them.
2320 applicator.walkAllPatterns(walk: [&](const Pattern &pattern) {
2321 std::optional<OperationName> root = pattern.getRootKind();
2322
2323 // If the pattern has no specific root, we can't analyze the relationship
2324 // between the root op and generated operations. Given that, add all such
2325 // patterns to the legalization set.
2326 if (!root) {
2327 anyOpLegalizerPatterns.push_back(Elt: &pattern);
2328 return;
2329 }
2330
2331 // Skip operations that are always known to be legal.
2332 if (target.getOpAction(op: *root) == LegalizationAction::Legal)
2333 return;
2334
2335 // Add this pattern to the invalid set for the root op and record this root
2336 // as a parent for any generated operations.
2337 invalidPatterns[*root].insert(Ptr: &pattern);
2338 for (auto op : pattern.getGeneratedOps())
2339 parentOps[op].insert(Ptr: *root);
2340
2341 // Add this pattern to the worklist.
2342 patternWorklist.insert(X: &pattern);
2343 });
2344
2345 // If there are any patterns that don't have a specific root kind, we can't
2346 // make direct assumptions about what operations will never be legalized.
2347 // Note: Technically we could, but it would require an analysis that may
2348 // recurse into itself. It would be better to perform this kind of filtering
2349 // at a higher level than here anyways.
2350 if (!anyOpLegalizerPatterns.empty()) {
2351 for (const Pattern *pattern : patternWorklist)
2352 legalizerPatterns[*pattern->getRootKind()].push_back(Elt: pattern);
2353 return;
2354 }
2355
2356 while (!patternWorklist.empty()) {
2357 auto *pattern = patternWorklist.pop_back_val();
2358
2359 // Check to see if any of the generated operations are invalid.
2360 if (llvm::any_of(Range: pattern->getGeneratedOps(), P: [&](OperationName op) {
2361 std::optional<LegalizationAction> action = target.getOpAction(op);
2362 return !legalizerPatterns.count(Val: op) &&
2363 (!action || action == LegalizationAction::Illegal);
2364 }))
2365 continue;
2366
2367 // Otherwise, if all of the generated operation are valid, this op is now
2368 // legal so add all of the child patterns to the worklist.
2369 legalizerPatterns[*pattern->getRootKind()].push_back(Elt: pattern);
2370 invalidPatterns[*pattern->getRootKind()].erase(Ptr: pattern);
2371
2372 // Add any invalid patterns of the parent operations to see if they have now
2373 // become legal.
2374 for (auto op : parentOps[*pattern->getRootKind()])
2375 patternWorklist.set_union(invalidPatterns[op]);
2376 }
2377}
2378
2379void OperationLegalizer::computeLegalizationGraphBenefit(
2380 LegalizationPatterns &anyOpLegalizerPatterns,
2381 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2382 // The smallest pattern depth, when legalizing an operation.
2383 DenseMap<OperationName, unsigned> minOpPatternDepth;
2384
2385 // For each operation that is transitively legal, compute a cost for it.
2386 for (auto &opIt : legalizerPatterns)
2387 if (!minOpPatternDepth.count(Val: opIt.first))
2388 computeOpLegalizationDepth(op: opIt.first, minOpPatternDepth,
2389 legalizerPatterns);
2390
2391 // Apply the cost model to the patterns that can match any operation. Those
2392 // with a specific operation type are already resolved when computing the op
2393 // legalization depth.
2394 if (!anyOpLegalizerPatterns.empty())
2395 applyCostModelToPatterns(patterns&: anyOpLegalizerPatterns, minOpPatternDepth,
2396 legalizerPatterns);
2397
2398 // Apply a cost model to the pattern applicator. We order patterns first by
2399 // depth then benefit. `legalizerPatterns` contains per-op patterns by
2400 // decreasing benefit.
2401 applicator.applyCostModel(model: [&](const Pattern &pattern) {
2402 ArrayRef<const Pattern *> orderedPatternList;
2403 if (std::optional<OperationName> rootName = pattern.getRootKind())
2404 orderedPatternList = legalizerPatterns[*rootName];
2405 else
2406 orderedPatternList = anyOpLegalizerPatterns;
2407
2408 // If the pattern is not found, then it was removed and cannot be matched.
2409 auto *it = llvm::find(Range&: orderedPatternList, Val: &pattern);
2410 if (it == orderedPatternList.end())
2411 return PatternBenefit::impossibleToMatch();
2412
2413 // Patterns found earlier in the list have higher benefit.
2414 return PatternBenefit(std::distance(first: it, last: orderedPatternList.end()));
2415 });
2416}
2417
2418unsigned OperationLegalizer::computeOpLegalizationDepth(
2419 OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2420 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2421 // Check for existing depth.
2422 auto depthIt = minOpPatternDepth.find(Val: op);
2423 if (depthIt != minOpPatternDepth.end())
2424 return depthIt->second;
2425
2426 // If a mapping for this operation does not exist, then this operation
2427 // is always legal. Return 0 as the depth for a directly legal operation.
2428 auto opPatternsIt = legalizerPatterns.find(Val: op);
2429 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2430 return 0u;
2431
2432 // Record this initial depth in case we encounter this op again when
2433 // recursively computing the depth.
2434 minOpPatternDepth.try_emplace(Key: op, Args: std::numeric_limits<unsigned>::max());
2435
2436 // Apply the cost model to the operation patterns, and update the minimum
2437 // depth.
2438 unsigned minDepth = applyCostModelToPatterns(
2439 patterns&: opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2440 minOpPatternDepth[op] = minDepth;
2441 return minDepth;
2442}
2443
2444unsigned OperationLegalizer::applyCostModelToPatterns(
2445 LegalizationPatterns &patterns,
2446 DenseMap<OperationName, unsigned> &minOpPatternDepth,
2447 DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2448 unsigned minDepth = std::numeric_limits<unsigned>::max();
2449
2450 // Compute the depth for each pattern within the set.
2451 SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
2452 patternsByDepth.reserve(N: patterns.size());
2453 for (const Pattern *pattern : patterns) {
2454 unsigned depth = 1;
2455 for (auto generatedOp : pattern->getGeneratedOps()) {
2456 unsigned generatedOpDepth = computeOpLegalizationDepth(
2457 op: generatedOp, minOpPatternDepth, legalizerPatterns);
2458 depth = std::max(a: depth, b: generatedOpDepth + 1);
2459 }
2460 patternsByDepth.emplace_back(Args&: pattern, Args&: depth);
2461
2462 // Update the minimum depth of the pattern list.
2463 minDepth = std::min(a: minDepth, b: depth);
2464 }
2465
2466 // If the operation only has one legalization pattern, there is no need to
2467 // sort them.
2468 if (patternsByDepth.size() == 1)
2469 return minDepth;
2470
2471 // Sort the patterns by those likely to be the most beneficial.
2472 std::stable_sort(first: patternsByDepth.begin(), last: patternsByDepth.end(),
2473 comp: [](const std::pair<const Pattern *, unsigned> &lhs,
2474 const std::pair<const Pattern *, unsigned> &rhs) {
2475 // First sort by the smaller pattern legalization
2476 // depth.
2477 if (lhs.second != rhs.second)
2478 return lhs.second < rhs.second;
2479
2480 // Then sort by the larger pattern benefit.
2481 auto lhsBenefit = lhs.first->getBenefit();
2482 auto rhsBenefit = rhs.first->getBenefit();
2483 return lhsBenefit > rhsBenefit;
2484 });
2485
2486 // Update the legalization pattern to use the new sorted list.
2487 patterns.clear();
2488 for (auto &patternIt : patternsByDepth)
2489 patterns.push_back(Elt: patternIt.first);
2490 return minDepth;
2491}
2492
2493//===----------------------------------------------------------------------===//
2494// OperationConverter
2495//===----------------------------------------------------------------------===//
namespace {
/// Selects how OperationConverter reacts to operations it fails to legalize,
/// and whether successful rewrites are kept.
enum OpConversionMode {
  /// Legalization failures are tolerated unless the operation was explicitly
  /// marked illegal, so illegal operations may remain in the IR.
  Partial = 0,

  /// Every operation must become legal for the target; a single failure fails
  /// the whole conversion.
  Full = 1,

  /// Operations are only analyzed for legality; recorded rewrites are rolled
  /// back even when the conversion succeeds.
  Analysis = 2
};
} // namespace
2511
namespace mlir {
// This class converts operations to a given conversion target via a set of
// rewrite patterns. The conversion behaves differently depending on the
// conversion mode (see OpConversionMode).
struct OperationConverter {
  // NOTE: `opLegalizer` is handed the *member* `this->config`, not the
  // constructor parameter. This relies on `config` being declared (and thus
  // initialized) before `opLegalizer` in the member list below; do not
  // reorder those members.
  explicit OperationConverter(const ConversionTarget &target,
                              const FrozenRewritePatternSet &patterns,
                              const ConversionConfig &config,
                              OpConversionMode mode)
      : config(config), opLegalizer(target, patterns, this->config),
        mode(mode) {}

  /// Converts the given operations (and the operations nested in them) to the
  /// conversion target. On failure, all recorded rewrites are undone. In
  /// Analysis mode, rewrites are undone even on success.
  LogicalResult convertOperations(ArrayRef<Operation *> ops);

private:
  /// Converts an operation with the given rewriter. A legalization failure is
  /// only reported as an error in Full mode, or in Partial mode for
  /// operations explicitly marked illegal.
  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);

  /// This method is called after the conversion process to legalize any
  /// remaining artifacts and complete the conversion.
  LogicalResult finalize(ConversionPatternRewriter &rewriter);

  /// Legalize the types of converted block arguments by asking each block
  /// type conversion rewrite to materialize conversions for live uses.
  LogicalResult
  legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
                                 ConversionPatternRewriterImpl &rewriterImpl);

  /// Legalize any unresolved type materializations. Also (re)computes the
  /// inverse value mapping and stores it in `inverseMapping`.
  LogicalResult legalizeUnresolvedMaterializations(
      ConversionPatternRewriter &rewriter,
      ConversionPatternRewriterImpl &rewriterImpl,
      std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping);

  /// Legalize an operation result that was marked as "erased". Fails if the
  /// result still has a user that survives the conversion.
  LogicalResult
  legalizeErasedResult(Operation *op, OpResult result,
                       ConversionPatternRewriterImpl &rewriterImpl);

  /// Legalize an operation result that was replaced with a value of a
  /// different type, by materializing a conversion back to the original type
  /// when the result (or a value it replaced) is still live.
  LogicalResult legalizeChangedResultType(
      Operation *op, OpResult result, Value newValue,
      const TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
      ConversionPatternRewriterImpl &rewriterImpl,
      const DenseMap<Value, SmallVector<Value>> &inverseMapping);

  /// Dialect conversion configuration (an owned copy of the one passed in).
  ConversionConfig config;

  /// The legalizer to use when converting operations.
  OperationLegalizer opLegalizer;

  /// The conversion mode to use when legalizing operations.
  OpConversionMode mode;
};
} // namespace mlir
2569
2570LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2571 Operation *op) {
2572 // Legalize the given operation.
2573 if (failed(result: opLegalizer.legalize(op, rewriter))) {
2574 // Handle the case of a failed conversion for each of the different modes.
2575 // Full conversions expect all operations to be converted.
2576 if (mode == OpConversionMode::Full)
2577 return op->emitError()
2578 << "failed to legalize operation '" << op->getName() << "'";
2579 // Partial conversions allow conversions to fail iff the operation was not
2580 // explicitly marked as illegal. If the user provided a `unlegalizedOps`
2581 // set, non-legalizable ops are added to that set.
2582 if (mode == OpConversionMode::Partial) {
2583 if (opLegalizer.isIllegal(op))
2584 return op->emitError()
2585 << "failed to legalize operation '" << op->getName()
2586 << "' that was explicitly marked illegal";
2587 if (config.unlegalizedOps)
2588 config.unlegalizedOps->insert(V: op);
2589 }
2590 } else if (mode == OpConversionMode::Analysis) {
2591 // Analysis conversions don't fail if any operations fail to legalize,
2592 // they are only interested in the operations that were successfully
2593 // legalized.
2594 if (config.legalizableOps)
2595 config.legalizableOps->insert(V: op);
2596 }
2597 return success();
2598}
2599
2600LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2601 if (ops.empty())
2602 return success();
2603 const ConversionTarget &target = opLegalizer.getTarget();
2604
2605 // Compute the set of operations and blocks to convert.
2606 SmallVector<Operation *> toConvert;
2607 for (auto *op : ops) {
2608 op->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
2609 callback: [&](Operation *op) {
2610 toConvert.push_back(Elt: op);
2611 // Don't check this operation's children for conversion if the
2612 // operation is recursively legal.
2613 auto legalityInfo = target.isLegal(op);
2614 if (legalityInfo && legalityInfo->isRecursivelyLegal)
2615 return WalkResult::skip();
2616 return WalkResult::advance();
2617 });
2618 }
2619
2620 // Convert each operation and discard rewrites on failure.
2621 ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
2622 ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2623
2624 for (auto *op : toConvert)
2625 if (failed(result: convert(rewriter, op)))
2626 return rewriterImpl.undoRewrites(), failure();
2627
2628 // Now that all of the operations have been converted, finalize the conversion
2629 // process to ensure any lingering conversion artifacts are cleaned up and
2630 // legalized.
2631 if (failed(result: finalize(rewriter)))
2632 return rewriterImpl.undoRewrites(), failure();
2633
2634 // After a successful conversion, apply rewrites if this is not an analysis
2635 // conversion.
2636 if (mode == OpConversionMode::Analysis) {
2637 rewriterImpl.undoRewrites();
2638 } else {
2639 rewriterImpl.applyRewrites();
2640 }
2641 return success();
2642}
2643
2644LogicalResult
2645OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
2646 std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping;
2647 ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2648 if (failed(result: legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
2649 inverseMapping)) ||
2650 failed(result: legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
2651 return failure();
2652
2653 // Process requested operation replacements.
2654 for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) {
2655 auto *opReplacement =
2656 dyn_cast<ReplaceOperationRewrite>(Val: rewriterImpl.rewrites[i].get());
2657 if (!opReplacement || !opReplacement->hasChangedResults())
2658 continue;
2659 Operation *op = opReplacement->getOperation();
2660 for (OpResult result : op->getResults()) {
2661 Value newValue = rewriterImpl.mapping.lookupOrNull(from: result);
2662
2663 // If the operation result was replaced with null, all of the uses of this
2664 // value should be replaced.
2665 if (!newValue) {
2666 if (failed(result: legalizeErasedResult(op, result, rewriterImpl)))
2667 return failure();
2668 continue;
2669 }
2670
2671 // Otherwise, check to see if the type of the result changed.
2672 if (result.getType() == newValue.getType())
2673 continue;
2674
2675 // Compute the inverse mapping only if it is really needed.
2676 if (!inverseMapping)
2677 inverseMapping = rewriterImpl.mapping.getInverse();
2678
2679 // Legalize this result.
2680 rewriter.setInsertionPoint(op);
2681 if (failed(result: legalizeChangedResultType(
2682 op, result, newValue, replConverter: opReplacement->getConverter(), rewriter,
2683 rewriterImpl, inverseMapping: *inverseMapping)))
2684 return failure();
2685 }
2686 }
2687 return success();
2688}
2689
2690LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
2691 ConversionPatternRewriter &rewriter,
2692 ConversionPatternRewriterImpl &rewriterImpl) {
2693 // Functor used to check if all users of a value will be dead after
2694 // conversion.
2695 auto findLiveUser = [&](Value val) {
2696 auto liveUserIt = llvm::find_if_not(Range: val.getUsers(), P: [&](Operation *user) {
2697 return rewriterImpl.isOpIgnored(op: user);
2698 });
2699 return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
2700 };
2701 // Note: `rewrites` may be reallocated as the loop is running.
2702 for (int64_t i = 0; i < static_cast<int64_t>(rewriterImpl.rewrites.size());
2703 ++i) {
2704 auto &rewrite = rewriterImpl.rewrites[i];
2705 if (auto *blockTypeConversionRewrite =
2706 dyn_cast<BlockTypeConversionRewrite>(Val: rewrite.get()))
2707 if (failed(result: blockTypeConversionRewrite->materializeLiveConversions(
2708 findLiveUser)))
2709 return failure();
2710 }
2711 return success();
2712}
2713
2714/// Replace the results of a materialization operation with the given values.
2715static void
2716replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl,
2717 ResultRange matResults, ValueRange values,
2718 DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2719 matResults.replaceAllUsesWith(values);
2720
2721 // For each of the materialization results, update the inverse mappings to
2722 // point to the replacement values.
2723 for (auto [matResult, newValue] : llvm::zip(t&: matResults, u&: values)) {
2724 auto inverseMapIt = inverseMapping.find(Val: matResult);
2725 if (inverseMapIt == inverseMapping.end())
2726 continue;
2727
2728 // Update the reverse mapping, or remove the mapping if we couldn't update
2729 // it. Not being able to update signals that the mapping would have become
2730 // circular (i.e. %foo -> newValue -> %foo), which may occur as values are
2731 // propagated through temporary materializations. We simply drop the
2732 // mapping, and let the post-conversion replacement logic handle updating
2733 // uses.
2734 for (Value inverseMapVal : inverseMapIt->second)
2735 if (!rewriterImpl.mapping.tryMap(oldVal: inverseMapVal, newVal: newValue))
2736 rewriterImpl.mapping.erase(value: inverseMapVal);
2737 }
2738}
2739
2740/// Compute all of the unresolved materializations that will persist beyond the
2741/// conversion process, and require inserting a proper user materialization for.
2742static void computeNecessaryMaterializations(
2743 DenseMap<Operation *, UnresolvedMaterializationRewrite *>
2744 &materializationOps,
2745 ConversionPatternRewriter &rewriter,
2746 ConversionPatternRewriterImpl &rewriterImpl,
2747 DenseMap<Value, SmallVector<Value>> &inverseMapping,
2748 SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) {
2749 auto isLive = [&](Value value) {
2750 auto findFn = [&](Operation *user) {
2751 auto matIt = materializationOps.find(Val: user);
2752 if (matIt != materializationOps.end())
2753 return !necessaryMaterializations.count(key: matIt->second);
2754 return rewriterImpl.isOpIgnored(op: user);
2755 };
2756 // This value may be replacing another value that has a live user.
2757 for (Value inv : inverseMapping.lookup(Val: value))
2758 if (llvm::find_if_not(Range: inv.getUsers(), P: findFn) != inv.user_end())
2759 return true;
2760 // Or have live users itself.
2761 return llvm::find_if_not(Range: value.getUsers(), P: findFn) != value.user_end();
2762 };
2763
2764 llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue =
2765 [&](Value invalidRoot, Value value, Type type) {
2766 // Check to see if the input operation was remapped to a variant of the
2767 // output.
2768 Value remappedValue = rewriterImpl.mapping.lookupOrDefault(from: value, desiredType: type);
2769 if (remappedValue.getType() == type && remappedValue != invalidRoot)
2770 return remappedValue;
2771
2772 // Check to see if the input is a materialization operation that
2773 // provides an inverse conversion. We just check blindly for
2774 // UnrealizedConversionCastOp here, but it has no effect on correctness.
2775 auto inputCastOp = value.getDefiningOp<UnrealizedConversionCastOp>();
2776 if (inputCastOp && inputCastOp->getNumOperands() == 1)
2777 return lookupRemappedValue(invalidRoot, inputCastOp->getOperand(0),
2778 type);
2779
2780 return Value();
2781 };
2782
2783 SetVector<UnresolvedMaterializationRewrite *> worklist;
2784 for (auto &rewrite : rewriterImpl.rewrites) {
2785 auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(Val: rewrite.get());
2786 if (!mat)
2787 continue;
2788 materializationOps.try_emplace(mat->getOperation(), mat);
2789 worklist.insert(X: mat);
2790 }
2791 while (!worklist.empty()) {
2792 UnresolvedMaterializationRewrite *mat = worklist.pop_back_val();
2793 UnrealizedConversionCastOp op = mat->getOperation();
2794
2795 // We currently only handle target materializations here.
2796 assert(op->getNumResults() == 1 && "unexpected materialization type");
2797 OpResult opResult = op->getOpResult(0);
2798 Type outputType = opResult.getType();
2799 Operation::operand_range inputOperands = op.getOperands();
2800
2801 // Try to forward propagate operands for user conversion casts that result
2802 // in the input types of the current cast.
2803 for (Operation *user : llvm::make_early_inc_range(opResult.getUsers())) {
2804 auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
2805 if (!castOp)
2806 continue;
2807 if (castOp->getResultTypes() == inputOperands.getTypes()) {
2808 replaceMaterialization(rewriterImpl, opResult, inputOperands,
2809 inverseMapping);
2810 necessaryMaterializations.remove(materializationOps.lookup(user));
2811 }
2812 }
2813
2814 // Try to avoid materializing a resolved materialization if possible.
2815 // Handle the case of a 1-1 materialization.
2816 if (inputOperands.size() == 1) {
2817 // Check to see if the input operation was remapped to a variant of the
2818 // output.
2819 Value remappedValue =
2820 lookupRemappedValue(opResult, inputOperands[0], outputType);
2821 if (remappedValue && remappedValue != opResult) {
2822 replaceMaterialization(rewriterImpl, matResults: opResult, values: remappedValue,
2823 inverseMapping);
2824 necessaryMaterializations.remove(X: mat);
2825 continue;
2826 }
2827 } else {
2828 // TODO: Avoid materializing other types of conversions here.
2829 }
2830
2831 // Check to see if this is an argument materialization.
2832 if (llvm::any_of(op->getOperands(), llvm::IsaPred<BlockArgument>) ||
2833 llvm::any_of(inverseMapping[op->getResult(0)],
2834 llvm::IsaPred<BlockArgument>)) {
2835 mat->setMaterializationKind(MaterializationKind::Argument);
2836 }
2837
2838 // If the materialization does not have any live users, we don't need to
2839 // generate a user materialization for it.
2840 // FIXME: For argument materializations, we currently need to check if any
2841 // of the inverse mapped values are used because some patterns expect blind
2842 // value replacement even if the types differ in some cases. When those
2843 // patterns are fixed, we can drop the argument special case here.
2844 bool isMaterializationLive = isLive(opResult);
2845 if (mat->getMaterializationKind() == MaterializationKind::Argument)
2846 isMaterializationLive |= llvm::any_of(Range&: inverseMapping[opResult], P: isLive);
2847 if (!isMaterializationLive)
2848 continue;
2849 if (!necessaryMaterializations.insert(X: mat))
2850 continue;
2851
2852 // Reprocess input materializations to see if they have an updated status.
2853 for (Value input : inputOperands) {
2854 if (auto parentOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
2855 if (auto *mat = materializationOps.lookup(parentOp))
2856 worklist.insert(mat);
2857 }
2858 }
2859 }
2860}
2861
2862/// Legalize the given unresolved materialization. Returns success if the
2863/// materialization was legalized, failure otherise.
2864static LogicalResult legalizeUnresolvedMaterialization(
2865 UnresolvedMaterializationRewrite &mat,
2866 DenseMap<Operation *, UnresolvedMaterializationRewrite *>
2867 &materializationOps,
2868 ConversionPatternRewriter &rewriter,
2869 ConversionPatternRewriterImpl &rewriterImpl,
2870 DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2871 auto findLiveUser = [&](auto &&users) {
2872 auto liveUserIt = llvm::find_if_not(
2873 users, [&](Operation *user) { return rewriterImpl.isOpIgnored(op: user); });
2874 return liveUserIt == users.end() ? nullptr : *liveUserIt;
2875 };
2876
2877 llvm::unique_function<Value(Value, Type)> lookupRemappedValue =
2878 [&](Value value, Type type) {
2879 // Check to see if the input operation was remapped to a variant of the
2880 // output.
2881 Value remappedValue = rewriterImpl.mapping.lookupOrDefault(from: value, desiredType: type);
2882 if (remappedValue.getType() == type)
2883 return remappedValue;
2884 return Value();
2885 };
2886
2887 UnrealizedConversionCastOp op = mat.getOperation();
2888 if (!rewriterImpl.ignoredOps.insert(op))
2889 return success();
2890
2891 // We currently only handle target materializations here.
2892 OpResult opResult = op->getOpResult(0);
2893 Operation::operand_range inputOperands = op.getOperands();
2894 Type outputType = opResult.getType();
2895
2896 // If any input to this materialization is another materialization, resolve
2897 // the input first.
2898 for (Value value : op->getOperands()) {
2899 auto valueCast = value.getDefiningOp<UnrealizedConversionCastOp>();
2900 if (!valueCast)
2901 continue;
2902
2903 auto matIt = materializationOps.find(valueCast);
2904 if (matIt != materializationOps.end())
2905 if (failed(legalizeUnresolvedMaterialization(
2906 *matIt->second, materializationOps, rewriter, rewriterImpl,
2907 inverseMapping)))
2908 return failure();
2909 }
2910
2911 // Perform a last ditch attempt to avoid materializing a resolved
2912 // materialization if possible.
2913 // Handle the case of a 1-1 materialization.
2914 if (inputOperands.size() == 1) {
2915 // Check to see if the input operation was remapped to a variant of the
2916 // output.
2917 Value remappedValue = lookupRemappedValue(inputOperands[0], outputType);
2918 if (remappedValue && remappedValue != opResult) {
2919 replaceMaterialization(rewriterImpl, matResults: opResult, values: remappedValue,
2920 inverseMapping);
2921 return success();
2922 }
2923 } else {
2924 // TODO: Avoid materializing other types of conversions here.
2925 }
2926
2927 // Try to materialize the conversion.
2928 if (const TypeConverter *converter = mat.getConverter()) {
2929 // FIXME: Determine a suitable insertion location when there are multiple
2930 // inputs.
2931 if (inputOperands.size() == 1)
2932 rewriter.setInsertionPointAfterValue(inputOperands.front());
2933 else
2934 rewriter.setInsertionPoint(op);
2935
2936 Value newMaterialization;
2937 switch (mat.getMaterializationKind()) {
2938 case MaterializationKind::Argument:
2939 // Try to materialize an argument conversion.
2940 // FIXME: The current argument materialization hook expects the original
2941 // output type, even though it doesn't use that as the actual output type
2942 // of the generated IR. The output type is just used as an indicator of
2943 // the type of materialization to do. This behavior is really awkward in
2944 // that it diverges from the behavior of the other hooks, and can be
2945 // easily misunderstood. We should clean up the argument hooks to better
2946 // represent the desired invariants we actually care about.
2947 newMaterialization = converter->materializeArgumentConversion(
2948 builder&: rewriter, loc: op->getLoc(), resultType: mat.getOrigOutputType(), inputs: inputOperands);
2949 if (newMaterialization)
2950 break;
2951
2952 // If an argument materialization failed, fallback to trying a target
2953 // materialization.
2954 [[fallthrough]];
2955 case MaterializationKind::Target:
2956 newMaterialization = converter->materializeTargetConversion(
2957 builder&: rewriter, loc: op->getLoc(), resultType: outputType, inputs: inputOperands);
2958 break;
2959 }
2960 if (newMaterialization) {
2961 replaceMaterialization(rewriterImpl, matResults: opResult, values: newMaterialization,
2962 inverseMapping);
2963 return success();
2964 }
2965 }
2966
2967 InFlightDiagnostic diag = op->emitError()
2968 << "failed to legalize unresolved materialization "
2969 "from "
2970 << inputOperands.getTypes() << " to " << outputType
2971 << " that remained live after conversion";
2972 if (Operation *liveUser = findLiveUser(op->getUsers())) {
2973 diag.attachNote(noteLoc: liveUser->getLoc())
2974 << "see existing live user here: " << *liveUser;
2975 }
2976 return failure();
2977}
2978
2979LogicalResult OperationConverter::legalizeUnresolvedMaterializations(
2980 ConversionPatternRewriter &rewriter,
2981 ConversionPatternRewriterImpl &rewriterImpl,
2982 std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping) {
2983 inverseMapping = rewriterImpl.mapping.getInverse();
2984
2985 // As an initial step, compute all of the inserted materializations that we
2986 // expect to persist beyond the conversion process.
2987 DenseMap<Operation *, UnresolvedMaterializationRewrite *> materializationOps;
2988 SetVector<UnresolvedMaterializationRewrite *> necessaryMaterializations;
2989 computeNecessaryMaterializations(materializationOps, rewriter, rewriterImpl,
2990 inverseMapping&: *inverseMapping, necessaryMaterializations);
2991
2992 // Once computed, legalize any necessary materializations.
2993 for (auto *mat : necessaryMaterializations) {
2994 if (failed(result: legalizeUnresolvedMaterialization(
2995 mat&: *mat, materializationOps, rewriter, rewriterImpl, inverseMapping&: *inverseMapping)))
2996 return failure();
2997 }
2998 return success();
2999}
3000
3001LogicalResult OperationConverter::legalizeErasedResult(
3002 Operation *op, OpResult result,
3003 ConversionPatternRewriterImpl &rewriterImpl) {
3004 // If the operation result was replaced with null, all of the uses of this
3005 // value should be replaced.
3006 auto liveUserIt = llvm::find_if_not(Range: result.getUsers(), P: [&](Operation *user) {
3007 return rewriterImpl.isOpIgnored(op: user);
3008 });
3009 if (liveUserIt != result.user_end()) {
3010 InFlightDiagnostic diag = op->emitError(message: "failed to legalize operation '")
3011 << op->getName() << "' marked as erased";
3012 diag.attachNote(noteLoc: liveUserIt->getLoc())
3013 << "found live user of result #" << result.getResultNumber() << ": "
3014 << *liveUserIt;
3015 return failure();
3016 }
3017 return success();
3018}
3019
3020/// Finds a user of the given value, or of any other value that the given value
3021/// replaced, that was not replaced in the conversion process.
3022static Operation *findLiveUserOfReplaced(
3023 Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
3024 const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
3025 SmallVector<Value> worklist(1, initialValue);
3026 while (!worklist.empty()) {
3027 Value value = worklist.pop_back_val();
3028
3029 // Walk the users of this value to see if there are any live users that
3030 // weren't replaced during conversion.
3031 auto liveUserIt = llvm::find_if_not(Range: value.getUsers(), P: [&](Operation *user) {
3032 return rewriterImpl.isOpIgnored(op: user);
3033 });
3034 if (liveUserIt != value.user_end())
3035 return *liveUserIt;
3036 auto mapIt = inverseMapping.find(Val: value);
3037 if (mapIt != inverseMapping.end())
3038 worklist.append(RHS: mapIt->second);
3039 }
3040 return nullptr;
3041}
3042
3043LogicalResult OperationConverter::legalizeChangedResultType(
3044 Operation *op, OpResult result, Value newValue,
3045 const TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
3046 ConversionPatternRewriterImpl &rewriterImpl,
3047 const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
3048 Operation *liveUser =
3049 findLiveUserOfReplaced(initialValue: result, rewriterImpl, inverseMapping);
3050 if (!liveUser)
3051 return success();
3052
3053 // Functor used to emit a conversion error for a failed materialization.
3054 auto emitConversionError = [&] {
3055 InFlightDiagnostic diag = op->emitError()
3056 << "failed to materialize conversion for result #"
3057 << result.getResultNumber() << " of operation '"
3058 << op->getName()
3059 << "' that remained live after conversion";
3060 diag.attachNote(noteLoc: liveUser->getLoc())
3061 << "see existing live user here: " << *liveUser;
3062 return failure();
3063 };
3064
3065 // If the replacement has a type converter, attempt to materialize a
3066 // conversion back to the original type.
3067 if (!replConverter)
3068 return emitConversionError();
3069
3070 // Materialize a conversion for this live result value.
3071 Type resultType = result.getType();
3072 Value convertedValue = replConverter->materializeSourceConversion(
3073 builder&: rewriter, loc: op->getLoc(), resultType, inputs: newValue);
3074 if (!convertedValue)
3075 return emitConversionError();
3076
3077 rewriterImpl.mapping.map(oldVal: result, newVal: convertedValue);
3078 return success();
3079}
3080
3081//===----------------------------------------------------------------------===//
3082// Type Conversion
3083//===----------------------------------------------------------------------===//
3084
3085void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
3086 ArrayRef<Type> types) {
3087 assert(!types.empty() && "expected valid types");
3088 remapInput(origInputNo, /*newInputNo=*/argTypes.size(), newInputCount: types.size());
3089 addInputs(types);
3090}
3091
3092void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) {
3093 assert(!types.empty() &&
3094 "1->0 type remappings don't need to be added explicitly");
3095 argTypes.append(in_start: types.begin(), in_end: types.end());
3096}
3097
3098void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
3099 unsigned newInputNo,
3100 unsigned newInputCount) {
3101 assert(!remappedInputs[origInputNo] && "input has already been remapped");
3102 assert(newInputCount != 0 && "expected valid input count");
3103 remappedInputs[origInputNo] =
3104 InputMapping{.inputNo: newInputNo, .size: newInputCount, /*replacementValue=*/nullptr};
3105}
3106
3107void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
3108 Value replacementValue) {
3109 assert(!remappedInputs[origInputNo] && "input has already been remapped");
3110 remappedInputs[origInputNo] =
3111 InputMapping{.inputNo: origInputNo, /*size=*/0, .replacementValue: replacementValue};
3112}
3113
3114LogicalResult TypeConverter::convertType(Type t,
3115 SmallVectorImpl<Type> &results) const {
3116 {
3117 std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
3118 std::defer_lock);
3119 if (t.getContext()->isMultithreadingEnabled())
3120 cacheReadLock.lock();
3121 auto existingIt = cachedDirectConversions.find(Val: t);
3122 if (existingIt != cachedDirectConversions.end()) {
3123 if (existingIt->second)
3124 results.push_back(Elt: existingIt->second);
3125 return success(isSuccess: existingIt->second != nullptr);
3126 }
3127 auto multiIt = cachedMultiConversions.find(Val: t);
3128 if (multiIt != cachedMultiConversions.end()) {
3129 results.append(in_start: multiIt->second.begin(), in_end: multiIt->second.end());
3130 return success();
3131 }
3132 }
3133 // Walk the added converters in reverse order to apply the most recently
3134 // registered first.
3135 size_t currentCount = results.size();
3136
3137 std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3138 std::defer_lock);
3139
3140 for (const ConversionCallbackFn &converter : llvm::reverse(C: conversions)) {
3141 if (std::optional<LogicalResult> result = converter(t, results)) {
3142 if (t.getContext()->isMultithreadingEnabled())
3143 cacheWriteLock.lock();
3144 if (!succeeded(result: *result)) {
3145 cachedDirectConversions.try_emplace(Key: t, Args: nullptr);
3146 return failure();
3147 }
3148 auto newTypes = ArrayRef<Type>(results).drop_front(N: currentCount);
3149 if (newTypes.size() == 1)
3150 cachedDirectConversions.try_emplace(Key: t, Args: newTypes.front());
3151 else
3152 cachedMultiConversions.try_emplace(Key: t, Args: llvm::to_vector<2>(Range&: newTypes));
3153 return success();
3154 }
3155 }
3156 return failure();
3157}
3158
3159Type TypeConverter::convertType(Type t) const {
3160 // Use the multi-type result version to convert the type.
3161 SmallVector<Type, 1> results;
3162 if (failed(result: convertType(t, results)))
3163 return nullptr;
3164
3165 // Check to ensure that only one type was produced.
3166 return results.size() == 1 ? results.front() : nullptr;
3167}
3168
3169LogicalResult
3170TypeConverter::convertTypes(TypeRange types,
3171 SmallVectorImpl<Type> &results) const {
3172 for (Type type : types)
3173 if (failed(result: convertType(t: type, results)))
3174 return failure();
3175 return success();
3176}
3177
3178bool TypeConverter::isLegal(Type type) const {
3179 return convertType(t: type) == type;
3180}
3181bool TypeConverter::isLegal(Operation *op) const {
3182 return isLegal(range: op->getOperandTypes()) && isLegal(range: op->getResultTypes());
3183}
3184
3185bool TypeConverter::isLegal(Region *region) const {
3186 return llvm::all_of(Range&: *region, P: [this](Block &block) {
3187 return isLegal(range: block.getArgumentTypes());
3188 });
3189}
3190
3191bool TypeConverter::isSignatureLegal(FunctionType ty) const {
3192 return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
3193}
3194
3195LogicalResult
3196TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
3197 SignatureConversion &result) const {
3198 // Try to convert the given input type.
3199 SmallVector<Type, 1> convertedTypes;
3200 if (failed(result: convertType(t: type, results&: convertedTypes)))
3201 return failure();
3202
3203 // If this argument is being dropped, there is nothing left to do.
3204 if (convertedTypes.empty())
3205 return success();
3206
3207 // Otherwise, add the new inputs.
3208 result.addInputs(origInputNo: inputNo, types: convertedTypes);
3209 return success();
3210}
3211LogicalResult
3212TypeConverter::convertSignatureArgs(TypeRange types,
3213 SignatureConversion &result,
3214 unsigned origInputOffset) const {
3215 for (unsigned i = 0, e = types.size(); i != e; ++i)
3216 if (failed(result: convertSignatureArg(inputNo: origInputOffset + i, type: types[i], result)))
3217 return failure();
3218 return success();
3219}
3220
3221Value TypeConverter::materializeConversion(
3222 ArrayRef<MaterializationCallbackFn> materializations, OpBuilder &builder,
3223 Location loc, Type resultType, ValueRange inputs) const {
3224 for (const MaterializationCallbackFn &fn : llvm::reverse(C&: materializations))
3225 if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
3226 return *result;
3227 return nullptr;
3228}
3229
3230std::optional<TypeConverter::SignatureConversion>
3231TypeConverter::convertBlockSignature(Block *block) const {
3232 SignatureConversion conversion(block->getNumArguments());
3233 if (failed(result: convertSignatureArgs(types: block->getArgumentTypes(), result&: conversion)))
3234 return std::nullopt;
3235 return conversion;
3236}
3237
3238//===----------------------------------------------------------------------===//
3239// Type attribute conversion
3240//===----------------------------------------------------------------------===//
3241TypeConverter::AttributeConversionResult
3242TypeConverter::AttributeConversionResult::result(Attribute attr) {
3243 return AttributeConversionResult(attr, resultTag);
3244}
3245
3246TypeConverter::AttributeConversionResult
3247TypeConverter::AttributeConversionResult::na() {
3248 return AttributeConversionResult(nullptr, naTag);
3249}
3250
3251TypeConverter::AttributeConversionResult
3252TypeConverter::AttributeConversionResult::abort() {
3253 return AttributeConversionResult(nullptr, abortTag);
3254}
3255
3256bool TypeConverter::AttributeConversionResult::hasResult() const {
3257 return impl.getInt() == resultTag;
3258}
3259
3260bool TypeConverter::AttributeConversionResult::isNa() const {
3261 return impl.getInt() == naTag;
3262}
3263
3264bool TypeConverter::AttributeConversionResult::isAbort() const {
3265 return impl.getInt() == abortTag;
3266}
3267
3268Attribute TypeConverter::AttributeConversionResult::getResult() const {
3269 assert(hasResult() && "Cannot get result from N/A or abort");
3270 return impl.getPointer();
3271}
3272
3273std::optional<Attribute>
3274TypeConverter::convertTypeAttribute(Type type, Attribute attr) const {
3275 for (const TypeAttributeConversionCallbackFn &fn :
3276 llvm::reverse(C: typeAttributeConversions)) {
3277 AttributeConversionResult res = fn(type, attr);
3278 if (res.hasResult())
3279 return res.getResult();
3280 if (res.isAbort())
3281 return std::nullopt;
3282 }
3283 return std::nullopt;
3284}
3285
3286//===----------------------------------------------------------------------===//
3287// FunctionOpInterfaceSignatureConversion
3288//===----------------------------------------------------------------------===//
3289
3290static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3291 const TypeConverter &typeConverter,
3292 ConversionPatternRewriter &rewriter) {
3293 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3294 if (!type)
3295 return failure();
3296
3297 // Convert the original function types.
3298 TypeConverter::SignatureConversion result(type.getNumInputs());
3299 SmallVector<Type, 1> newResults;
3300 if (failed(typeConverter.convertSignatureArgs(types: type.getInputs(), result)) ||
3301 failed(typeConverter.convertTypes(types: type.getResults(), results&: newResults)) ||
3302 failed(rewriter.convertRegionTypes(region: &funcOp.getFunctionBody(),
3303 converter: typeConverter, entryConversion: &result)))
3304 return failure();
3305
3306 // Update the function signature in-place.
3307 auto newType = FunctionType::get(rewriter.getContext(),
3308 result.getConvertedTypes(), newResults);
3309
3310 rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3311
3312 return success();
3313}
3314
3315/// Create a default conversion pattern that rewrites the type signature of a
3316/// FunctionOpInterface op. This only supports ops which use FunctionType to
3317/// represent their type.
3318namespace {
3319struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3320 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3321 MLIRContext *ctx,
3322 const TypeConverter &converter)
3323 : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
3324
3325 LogicalResult
3326 matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3327 ConversionPatternRewriter &rewriter) const override {
3328 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3329 return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3330 }
3331};
3332
3333struct AnyFunctionOpInterfaceSignatureConversion
3334 : public OpInterfaceConversionPattern<FunctionOpInterface> {
3335 using OpInterfaceConversionPattern::OpInterfaceConversionPattern;
3336
3337 LogicalResult
3338 matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3339 ConversionPatternRewriter &rewriter) const override {
3340 return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3341 }
3342};
3343} // namespace
3344
3345FailureOr<Operation *>
3346mlir::convertOpResultTypes(Operation *op, ValueRange operands,
3347 const TypeConverter &converter,
3348 ConversionPatternRewriter &rewriter) {
3349 assert(op && "Invalid op");
3350 Location loc = op->getLoc();
3351 if (converter.isLegal(op))
3352 return rewriter.notifyMatchFailure(arg&: loc, msg: "op already legal");
3353
3354 OperationState newOp(loc, op->getName());
3355 newOp.addOperands(newOperands: operands);
3356
3357 SmallVector<Type> newResultTypes;
3358 if (failed(result: converter.convertTypes(types: op->getResultTypes(), results&: newResultTypes)))
3359 return rewriter.notifyMatchFailure(arg&: loc, msg: "couldn't convert return types");
3360
3361 newOp.addTypes(newTypes: newResultTypes);
3362 newOp.addAttributes(newAttributes: op->getAttrs());
3363 return rewriter.create(state: newOp);
3364}
3365
3366void mlir::populateFunctionOpInterfaceTypeConversionPattern(
3367 StringRef functionLikeOpName, RewritePatternSet &patterns,
3368 const TypeConverter &converter) {
3369 patterns.add<FunctionOpInterfaceSignatureConversion>(
3370 arg&: functionLikeOpName, args: patterns.getContext(), args: converter);
3371}
3372
3373void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
3374 RewritePatternSet &patterns, const TypeConverter &converter) {
3375 patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3376 arg: converter, args: patterns.getContext());
3377}
3378
3379//===----------------------------------------------------------------------===//
3380// ConversionTarget
3381//===----------------------------------------------------------------------===//
3382
3383void ConversionTarget::setOpAction(OperationName op,
3384 LegalizationAction action) {
3385 legalOperations[op].action = action;
3386}
3387
3388void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
3389 LegalizationAction action) {
3390 for (StringRef dialect : dialectNames)
3391 legalDialects[dialect] = action;
3392}
3393
3394auto ConversionTarget::getOpAction(OperationName op) const
3395 -> std::optional<LegalizationAction> {
3396 std::optional<LegalizationInfo> info = getOpInfo(op);
3397 return info ? info->action : std::optional<LegalizationAction>();
3398}
3399
3400auto ConversionTarget::isLegal(Operation *op) const
3401 -> std::optional<LegalOpDetails> {
3402 std::optional<LegalizationInfo> info = getOpInfo(op: op->getName());
3403 if (!info)
3404 return std::nullopt;
3405
3406 // Returns true if this operation instance is known to be legal.
3407 auto isOpLegal = [&] {
3408 // Handle dynamic legality either with the provided legality function.
3409 if (info->action == LegalizationAction::Dynamic) {
3410 std::optional<bool> result = info->legalityFn(op);
3411 if (result)
3412 return *result;
3413 }
3414
3415 // Otherwise, the operation is only legal if it was marked 'Legal'.
3416 return info->action == LegalizationAction::Legal;
3417 };
3418 if (!isOpLegal())
3419 return std::nullopt;
3420
3421 // This operation is legal, compute any additional legality information.
3422 LegalOpDetails legalityDetails;
3423 if (info->isRecursivelyLegal) {
3424 auto legalityFnIt = opRecursiveLegalityFns.find(Val: op->getName());
3425 if (legalityFnIt != opRecursiveLegalityFns.end()) {
3426 legalityDetails.isRecursivelyLegal =
3427 legalityFnIt->second(op).value_or(u: true);
3428 } else {
3429 legalityDetails.isRecursivelyLegal = true;
3430 }
3431 }
3432 return legalityDetails;
3433}
3434
3435bool ConversionTarget::isIllegal(Operation *op) const {
3436 std::optional<LegalizationInfo> info = getOpInfo(op: op->getName());
3437 if (!info)
3438 return false;
3439
3440 if (info->action == LegalizationAction::Dynamic) {
3441 std::optional<bool> result = info->legalityFn(op);
3442 if (!result)
3443 return false;
3444
3445 return !(*result);
3446 }
3447
3448 return info->action == LegalizationAction::Illegal;
3449}
3450
3451static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(
3452 ConversionTarget::DynamicLegalityCallbackFn oldCallback,
3453 ConversionTarget::DynamicLegalityCallbackFn newCallback) {
3454 if (!oldCallback)
3455 return newCallback;
3456
3457 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3458 Operation *op) -> std::optional<bool> {
3459 if (std::optional<bool> result = newCl(op))
3460 return *result;
3461
3462 return oldCl(op);
3463 };
3464 return chain;
3465}
3466
3467void ConversionTarget::setLegalityCallback(
3468 OperationName name, const DynamicLegalityCallbackFn &callback) {
3469 assert(callback && "expected valid legality callback");
3470 auto *infoIt = legalOperations.find(Key: name);
3471 assert(infoIt != legalOperations.end() &&
3472 infoIt->second.action == LegalizationAction::Dynamic &&
3473 "expected operation to already be marked as dynamically legal");
3474 infoIt->second.legalityFn =
3475 composeLegalityCallbacks(oldCallback: std::move(infoIt->second.legalityFn), newCallback: callback);
3476}
3477
3478void ConversionTarget::markOpRecursivelyLegal(
3479 OperationName name, const DynamicLegalityCallbackFn &callback) {
3480 auto *infoIt = legalOperations.find(Key: name);
3481 assert(infoIt != legalOperations.end() &&
3482 infoIt->second.action != LegalizationAction::Illegal &&
3483 "expected operation to already be marked as legal");
3484 infoIt->second.isRecursivelyLegal = true;
3485 if (callback)
3486 opRecursiveLegalityFns[name] = composeLegalityCallbacks(
3487 oldCallback: std::move(opRecursiveLegalityFns[name]), newCallback: callback);
3488 else
3489 opRecursiveLegalityFns.erase(Val: name);
3490}
3491
3492void ConversionTarget::setLegalityCallback(
3493 ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
3494 assert(callback && "expected valid legality callback");
3495 for (StringRef dialect : dialects)
3496 dialectLegalityFns[dialect] = composeLegalityCallbacks(
3497 oldCallback: std::move(dialectLegalityFns[dialect]), newCallback: callback);
3498}
3499
3500void ConversionTarget::setLegalityCallback(
3501 const DynamicLegalityCallbackFn &callback) {
3502 assert(callback && "expected valid legality callback");
3503 unknownLegalityFn = composeLegalityCallbacks(oldCallback: unknownLegalityFn, newCallback: callback);
3504}
3505
3506auto ConversionTarget::getOpInfo(OperationName op) const
3507 -> std::optional<LegalizationInfo> {
3508 // Check for info for this specific operation.
3509 const auto *it = legalOperations.find(Key: op);
3510 if (it != legalOperations.end())
3511 return it->second;
3512 // Check for info for the parent dialect.
3513 auto dialectIt = legalDialects.find(Key: op.getDialectNamespace());
3514 if (dialectIt != legalDialects.end()) {
3515 DynamicLegalityCallbackFn callback;
3516 auto dialectFn = dialectLegalityFns.find(Key: op.getDialectNamespace());
3517 if (dialectFn != dialectLegalityFns.end())
3518 callback = dialectFn->second;
3519 return LegalizationInfo{.action: dialectIt->second, /*isRecursivelyLegal=*/false,
3520 .legalityFn: callback};
3521 }
3522 // Otherwise, check if we mark unknown operations as dynamic.
3523 if (unknownLegalityFn)
3524 return LegalizationInfo{.action: LegalizationAction::Dynamic,
3525 /*isRecursivelyLegal=*/false, .legalityFn: unknownLegalityFn};
3526 return std::nullopt;
3527}
3528
3529#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3530//===----------------------------------------------------------------------===//
3531// PDL Configuration
3532//===----------------------------------------------------------------------===//
3533
3534void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) {
3535 auto &rewriterImpl =
3536 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3537 rewriterImpl.currentTypeConverter = getTypeConverter();
3538}
3539
3540void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) {
3541 auto &rewriterImpl =
3542 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3543 rewriterImpl.currentTypeConverter = nullptr;
3544}
3545
3546/// Remap the given value using the rewriter and the type converter in the
3547/// provided config.
3548static FailureOr<SmallVector<Value>>
3549pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values) {
3550 SmallVector<Value> mappedValues;
3551 if (failed(result: rewriter.getRemappedValues(keys: values, results&: mappedValues)))
3552 return failure();
3553 return std::move(mappedValues);
3554}
3555
3556void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
3557 patterns.getPDLPatterns().registerRewriteFunction(
3558 name: "convertValue",
3559 rewriteFn: [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
3560 auto results = pdllConvertValues(
3561 rewriter&: static_cast<ConversionPatternRewriter &>(rewriter), values: value);
3562 if (failed(result: results))
3563 return failure();
3564 return results->front();
3565 });
3566 patterns.getPDLPatterns().registerRewriteFunction(
3567 name: "convertValues", rewriteFn: [](PatternRewriter &rewriter, ValueRange values) {
3568 return pdllConvertValues(
3569 rewriter&: static_cast<ConversionPatternRewriter &>(rewriter), values);
3570 });
3571 patterns.getPDLPatterns().registerRewriteFunction(
3572 name: "convertType",
3573 rewriteFn: [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
3574 auto &rewriterImpl =
3575 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3576 if (const TypeConverter *converter =
3577 rewriterImpl.currentTypeConverter) {
3578 if (Type newType = converter->convertType(t: type))
3579 return newType;
3580 return failure();
3581 }
3582 return type;
3583 });
3584 patterns.getPDLPatterns().registerRewriteFunction(
3585 name: "convertTypes",
3586 rewriteFn: [](PatternRewriter &rewriter,
3587 TypeRange types) -> FailureOr<SmallVector<Type>> {
3588 auto &rewriterImpl =
3589 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3590 const TypeConverter *converter = rewriterImpl.currentTypeConverter;
3591 if (!converter)
3592 return SmallVector<Type>(types);
3593
3594 SmallVector<Type> remappedTypes;
3595 if (failed(result: converter->convertTypes(types, results&: remappedTypes)))
3596 return failure();
3597 return std::move(remappedTypes);
3598 });
3599}
3600#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
3601
3602//===----------------------------------------------------------------------===//
3603// Op Conversion Entry Points
3604//===----------------------------------------------------------------------===//
3605
3606//===----------------------------------------------------------------------===//
3607// Partial Conversion
3608
3609LogicalResult mlir::applyPartialConversion(
3610 ArrayRef<Operation *> ops, const ConversionTarget &target,
3611 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3612 OperationConverter opConverter(target, patterns, config,
3613 OpConversionMode::Partial);
3614 return opConverter.convertOperations(ops);
3615}
3616LogicalResult
3617mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
3618 const FrozenRewritePatternSet &patterns,
3619 ConversionConfig config) {
3620 return applyPartialConversion(ops: llvm::ArrayRef(op), target, patterns, config);
3621}
3622
3623//===----------------------------------------------------------------------===//
3624// Full Conversion
3625
3626LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
3627 const ConversionTarget &target,
3628 const FrozenRewritePatternSet &patterns,
3629 ConversionConfig config) {
3630 OperationConverter opConverter(target, patterns, config,
3631 OpConversionMode::Full);
3632 return opConverter.convertOperations(ops);
3633}
3634LogicalResult mlir::applyFullConversion(Operation *op,
3635 const ConversionTarget &target,
3636 const FrozenRewritePatternSet &patterns,
3637 ConversionConfig config) {
3638 return applyFullConversion(ops: llvm::ArrayRef(op), target, patterns, config);
3639}
3640
3641//===----------------------------------------------------------------------===//
3642// Analysis Conversion
3643
3644LogicalResult mlir::applyAnalysisConversion(
3645 ArrayRef<Operation *> ops, ConversionTarget &target,
3646 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3647 OperationConverter opConverter(target, patterns, config,
3648 OpConversionMode::Analysis);
3649 return opConverter.convertOperations(ops);
3650}
3651LogicalResult
3652mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
3653 const FrozenRewritePatternSet &patterns,
3654 ConversionConfig config) {
3655 return applyAnalysisConversion(ops: llvm::ArrayRef(op), target, patterns, config);
3656}
3657

source code of mlir/lib/Transforms/Utils/DialectConversion.cpp