1 | //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Transforms/DialectConversion.h" |
10 | #include "mlir/Config/mlir-config.h" |
11 | #include "mlir/IR/Block.h" |
12 | #include "mlir/IR/Builders.h" |
13 | #include "mlir/IR/BuiltinOps.h" |
14 | #include "mlir/IR/IRMapping.h" |
15 | #include "mlir/IR/Iterators.h" |
16 | #include "mlir/Interfaces/FunctionInterfaces.h" |
17 | #include "mlir/Rewrite/PatternApplicator.h" |
18 | #include "llvm/ADT/ScopeExit.h" |
19 | #include "llvm/ADT/SetVector.h" |
20 | #include "llvm/ADT/SmallPtrSet.h" |
21 | #include "llvm/Support/Debug.h" |
22 | #include "llvm/Support/FormatVariadic.h" |
23 | #include "llvm/Support/SaveAndRestore.h" |
24 | #include "llvm/Support/ScopedPrinter.h" |
25 | #include <optional> |
26 | |
27 | using namespace mlir; |
28 | using namespace mlir::detail; |
29 | |
30 | #define DEBUG_TYPE "dialect-conversion" |
31 | |
32 | /// A utility function to log a successful result for the given reason. |
33 | template <typename... Args> |
34 | static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) { |
35 | LLVM_DEBUG({ |
36 | os.unindent(); |
37 | os.startLine() << "} -> SUCCESS" ; |
38 | if (!fmt.empty()) |
39 | os.getOStream() << " : " |
40 | << llvm::formatv(fmt.data(), std::forward<Args>(args)...); |
41 | os.getOStream() << "\n" ; |
42 | }); |
43 | } |
44 | |
45 | /// A utility function to log a failure result for the given reason. |
46 | template <typename... Args> |
47 | static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) { |
48 | LLVM_DEBUG({ |
49 | os.unindent(); |
50 | os.startLine() << "} -> FAILURE : " |
51 | << llvm::formatv(fmt.data(), std::forward<Args>(args)...) |
52 | << "\n" ; |
53 | }); |
54 | } |
55 | |
56 | //===----------------------------------------------------------------------===// |
57 | // ConversionValueMapping |
58 | //===----------------------------------------------------------------------===// |
59 | |
60 | namespace { |
61 | /// This class wraps a IRMapping to provide recursive lookup |
62 | /// functionality, i.e. we will traverse if the mapped value also has a mapping. |
63 | struct ConversionValueMapping { |
64 | /// Lookup a mapped value within the map. If a mapping for the provided value |
65 | /// does not exist then return the provided value. If `desiredType` is |
66 | /// non-null, returns the most recently mapped value with that type. If an |
67 | /// operand of that type does not exist, defaults to normal behavior. |
68 | Value lookupOrDefault(Value from, Type desiredType = nullptr) const; |
69 | |
70 | /// Lookup a mapped value within the map, or return null if a mapping does not |
71 | /// exist. If a mapping exists, this follows the same behavior of |
72 | /// `lookupOrDefault`. |
73 | Value lookupOrNull(Value from, Type desiredType = nullptr) const; |
74 | |
75 | /// Map a value to the one provided. |
76 | void map(Value oldVal, Value newVal) { |
77 | LLVM_DEBUG({ |
78 | for (Value it = newVal; it; it = mapping.lookupOrNull(it)) |
79 | assert(it != oldVal && "inserting cyclic mapping" ); |
80 | }); |
81 | mapping.map(from: oldVal, to: newVal); |
82 | } |
83 | |
84 | /// Try to map a value to the one provided. Returns false if a transitive |
85 | /// mapping from the new value to the old value already exists, true if the |
86 | /// map was updated. |
87 | bool tryMap(Value oldVal, Value newVal); |
88 | |
89 | /// Drop the last mapping for the given value. |
90 | void erase(Value value) { mapping.erase(from: value); } |
91 | |
92 | /// Returns the inverse raw value mapping (without recursive query support). |
93 | DenseMap<Value, SmallVector<Value>> getInverse() const { |
94 | DenseMap<Value, SmallVector<Value>> inverse; |
95 | for (auto &it : mapping.getValueMap()) |
96 | inverse[it.second].push_back(Elt: it.first); |
97 | return inverse; |
98 | } |
99 | |
100 | private: |
101 | /// Current value mappings. |
102 | IRMapping mapping; |
103 | }; |
104 | } // namespace |
105 | |
106 | Value ConversionValueMapping::lookupOrDefault(Value from, |
107 | Type desiredType) const { |
108 | // If there was no desired type, simply find the leaf value. |
109 | if (!desiredType) { |
110 | // If this value had a valid mapping, unmap that value as well in the case |
111 | // that it was also replaced. |
112 | while (auto mappedValue = mapping.lookupOrNull(from)) |
113 | from = mappedValue; |
114 | return from; |
115 | } |
116 | |
117 | // Otherwise, try to find the deepest value that has the desired type. |
118 | Value desiredValue; |
119 | do { |
120 | if (from.getType() == desiredType) |
121 | desiredValue = from; |
122 | |
123 | Value mappedValue = mapping.lookupOrNull(from); |
124 | if (!mappedValue) |
125 | break; |
126 | from = mappedValue; |
127 | } while (true); |
128 | |
129 | // If the desired value was found use it, otherwise default to the leaf value. |
130 | return desiredValue ? desiredValue : from; |
131 | } |
132 | |
133 | Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const { |
134 | Value result = lookupOrDefault(from, desiredType); |
135 | if (result == from || (desiredType && result.getType() != desiredType)) |
136 | return nullptr; |
137 | return result; |
138 | } |
139 | |
140 | bool ConversionValueMapping::tryMap(Value oldVal, Value newVal) { |
141 | for (Value it = newVal; it; it = mapping.lookupOrNull(from: it)) |
142 | if (it == oldVal) |
143 | return false; |
144 | map(oldVal, newVal); |
145 | return true; |
146 | } |
147 | |
148 | //===----------------------------------------------------------------------===// |
149 | // Rewriter and Translation State |
150 | //===----------------------------------------------------------------------===// |
151 | namespace { |
/// Snapshot of the conversion rewriter's bookkeeping counters. Capturing one
/// and later resetting to it allows a group of rewrites to be undone.
struct RewriterState {
  RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
                unsigned numReplacedOps)
      : numRewrites{numRewrites}, numIgnoredOperations{numIgnoredOperations},
        numReplacedOps{numReplacedOps} {}

  /// Number of rewrites recorded at snapshot time.
  unsigned numRewrites;

  /// Number of ignored operations at snapshot time.
  unsigned numIgnoredOperations;

  /// Number of replaced operations scheduled for erasure at snapshot time.
  unsigned numReplacedOps;
};
169 | |
170 | //===----------------------------------------------------------------------===// |
171 | // IR rewrites |
172 | //===----------------------------------------------------------------------===// |
173 | |
174 | /// An IR rewrite that can be committed (upon success) or rolled back (upon |
175 | /// failure). |
176 | /// |
177 | /// The dialect conversion keeps track of IR modifications (requested by the |
178 | /// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites |
179 | /// are directly applied to the IR as the rewriter API is used, some are applied |
180 | /// partially, and some are delayed until the `IRRewrite` objects are committed. |
181 | class IRRewrite { |
182 | public: |
183 | /// The kind of the rewrite. Rewrites can be undone if the conversion fails. |
184 | /// Enum values are ordered, so that they can be used in `classof`: first all |
185 | /// block rewrites, then all operation rewrites. |
186 | enum class Kind { |
187 | // Block rewrites |
188 | CreateBlock, |
189 | EraseBlock, |
190 | InlineBlock, |
191 | MoveBlock, |
192 | BlockTypeConversion, |
193 | ReplaceBlockArg, |
194 | // Operation rewrites |
195 | MoveOperation, |
196 | ModifyOperation, |
197 | ReplaceOperation, |
198 | CreateOperation, |
199 | UnresolvedMaterialization |
200 | }; |
201 | |
202 | virtual ~IRRewrite() = default; |
203 | |
204 | /// Roll back the rewrite. Operations may be erased during rollback. |
205 | virtual void rollback() = 0; |
206 | |
207 | /// Commit the rewrite. At this point, it is certain that the dialect |
208 | /// conversion will succeed. All IR modifications, except for operation/block |
209 | /// erasure, must be performed through the given rewriter. |
210 | /// |
211 | /// Instead of erasing operations/blocks, they should merely be unlinked |
212 | /// commit phase and finally be erased during the cleanup phase. This is |
213 | /// because internal dialect conversion state (such as `mapping`) may still |
214 | /// be using them. |
215 | /// |
216 | /// Any IR modification that was already performed before the commit phase |
217 | /// (e.g., insertion of an op) must be communicated to the listener that may |
218 | /// be attached to the given rewriter. |
219 | virtual void commit(RewriterBase &rewriter) {} |
220 | |
221 | /// Cleanup operations/blocks. Cleanup is called after commit. |
222 | virtual void cleanup(RewriterBase &rewriter) {} |
223 | |
224 | Kind getKind() const { return kind; } |
225 | |
226 | static bool classof(const IRRewrite *rewrite) { return true; } |
227 | |
228 | protected: |
229 | IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl) |
230 | : kind(kind), rewriterImpl(rewriterImpl) {} |
231 | |
232 | const ConversionConfig &getConfig() const; |
233 | |
234 | const Kind kind; |
235 | ConversionPatternRewriterImpl &rewriterImpl; |
236 | }; |
237 | |
238 | /// A block rewrite. |
239 | class BlockRewrite : public IRRewrite { |
240 | public: |
241 | /// Return the block that this rewrite operates on. |
242 | Block *getBlock() const { return block; } |
243 | |
244 | static bool classof(const IRRewrite *rewrite) { |
245 | return rewrite->getKind() >= Kind::CreateBlock && |
246 | rewrite->getKind() <= Kind::ReplaceBlockArg; |
247 | } |
248 | |
249 | protected: |
250 | BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl, |
251 | Block *block) |
252 | : IRRewrite(kind, rewriterImpl), block(block) {} |
253 | |
254 | // The block that this rewrite operates on. |
255 | Block *block; |
256 | }; |
257 | |
258 | /// Creation of a block. Block creations are immediately reflected in the IR. |
259 | /// There is no extra work to commit the rewrite. During rollback, the newly |
260 | /// created block is erased. |
261 | class CreateBlockRewrite : public BlockRewrite { |
262 | public: |
263 | CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block) |
264 | : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {} |
265 | |
266 | static bool classof(const IRRewrite *rewrite) { |
267 | return rewrite->getKind() == Kind::CreateBlock; |
268 | } |
269 | |
270 | void commit(RewriterBase &rewriter) override { |
271 | // The block was already created and inserted. Just inform the listener. |
272 | if (auto *listener = rewriter.getListener()) |
273 | listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{}); |
274 | } |
275 | |
276 | void rollback() override { |
277 | // Unlink all of the operations within this block, they will be deleted |
278 | // separately. |
279 | auto &blockOps = block->getOperations(); |
280 | while (!blockOps.empty()) |
281 | blockOps.remove(IT: blockOps.begin()); |
282 | block->dropAllUses(); |
283 | if (block->getParent()) |
284 | block->erase(); |
285 | else |
286 | delete block; |
287 | } |
288 | }; |
289 | |
290 | /// Erasure of a block. Block erasures are partially reflected in the IR. Erased |
291 | /// blocks are immediately unlinked, but only erased during cleanup. This makes |
292 | /// it easier to rollback a block erasure: the block is simply inserted into its |
293 | /// original location. |
294 | class EraseBlockRewrite : public BlockRewrite { |
295 | public: |
296 | EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block, |
297 | Region *region, Block *insertBeforeBlock) |
298 | : BlockRewrite(Kind::EraseBlock, rewriterImpl, block), region(region), |
299 | insertBeforeBlock(insertBeforeBlock) {} |
300 | |
301 | static bool classof(const IRRewrite *rewrite) { |
302 | return rewrite->getKind() == Kind::EraseBlock; |
303 | } |
304 | |
305 | ~EraseBlockRewrite() override { |
306 | assert(!block && |
307 | "rewrite was neither rolled back nor committed/cleaned up" ); |
308 | } |
309 | |
310 | void rollback() override { |
311 | // The block (owned by this rewrite) was not actually erased yet. It was |
312 | // just unlinked. Put it back into its original position. |
313 | assert(block && "expected block" ); |
314 | auto &blockList = region->getBlocks(); |
315 | Region::iterator before = insertBeforeBlock |
316 | ? Region::iterator(insertBeforeBlock) |
317 | : blockList.end(); |
318 | blockList.insert(where: before, New: block); |
319 | block = nullptr; |
320 | } |
321 | |
322 | void commit(RewriterBase &rewriter) override { |
323 | // Erase the block. |
324 | assert(block && "expected block" ); |
325 | assert(block->empty() && "expected empty block" ); |
326 | |
327 | // Notify the listener that the block is about to be erased. |
328 | if (auto *listener = |
329 | dyn_cast_or_null<RewriterBase::Listener>(Val: rewriter.getListener())) |
330 | listener->notifyBlockErased(block); |
331 | } |
332 | |
333 | void cleanup(RewriterBase &rewriter) override { |
334 | // Erase the block. |
335 | block->dropAllDefinedValueUses(); |
336 | delete block; |
337 | block = nullptr; |
338 | } |
339 | |
340 | private: |
341 | // The region in which this block was previously contained. |
342 | Region *region; |
343 | |
344 | // The original successor of this block before it was unlinked. "nullptr" if |
345 | // this block was the only block in the region. |
346 | Block *insertBeforeBlock; |
347 | }; |
348 | |
349 | /// Inlining of a block. This rewrite is immediately reflected in the IR. |
350 | /// Note: This rewrite represents only the inlining of the operations. The |
351 | /// erasure of the inlined block is a separate rewrite. |
352 | class InlineBlockRewrite : public BlockRewrite { |
353 | public: |
354 | InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block, |
355 | Block *sourceBlock, Block::iterator before) |
356 | : BlockRewrite(Kind::InlineBlock, rewriterImpl, block), |
357 | sourceBlock(sourceBlock), |
358 | firstInlinedInst(sourceBlock->empty() ? nullptr |
359 | : &sourceBlock->front()), |
360 | lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) { |
361 | // If a listener is attached to the dialect conversion, ops must be moved |
362 | // one-by-one. When they are moved in bulk, notifications cannot be sent |
363 | // because the ops that used to be in the source block at the time of the |
364 | // inlining (before the "commit" phase) are unknown at the time when |
365 | // notifications are sent (which is during the "commit" phase). |
366 | assert(!getConfig().listener && |
367 | "InlineBlockRewrite not supported if listener is attached" ); |
368 | } |
369 | |
370 | static bool classof(const IRRewrite *rewrite) { |
371 | return rewrite->getKind() == Kind::InlineBlock; |
372 | } |
373 | |
374 | void rollback() override { |
375 | // Put the operations from the destination block (owned by the rewrite) |
376 | // back into the source block. |
377 | if (firstInlinedInst) { |
378 | assert(lastInlinedInst && "expected operation" ); |
379 | sourceBlock->getOperations().splice(where: sourceBlock->begin(), |
380 | L2&: block->getOperations(), |
381 | first: Block::iterator(firstInlinedInst), |
382 | last: ++Block::iterator(lastInlinedInst)); |
383 | } |
384 | } |
385 | |
386 | private: |
387 | // The block that originally contained the operations. |
388 | Block *sourceBlock; |
389 | |
390 | // The first inlined operation. |
391 | Operation *firstInlinedInst; |
392 | |
393 | // The last inlined operation. |
394 | Operation *lastInlinedInst; |
395 | }; |
396 | |
397 | /// Moving of a block. This rewrite is immediately reflected in the IR. |
398 | class MoveBlockRewrite : public BlockRewrite { |
399 | public: |
400 | MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block, |
401 | Region *region, Block *insertBeforeBlock) |
402 | : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region), |
403 | insertBeforeBlock(insertBeforeBlock) {} |
404 | |
405 | static bool classof(const IRRewrite *rewrite) { |
406 | return rewrite->getKind() == Kind::MoveBlock; |
407 | } |
408 | |
409 | void commit(RewriterBase &rewriter) override { |
410 | // The block was already moved. Just inform the listener. |
411 | if (auto *listener = rewriter.getListener()) { |
412 | // Note: `previousIt` cannot be passed because this is a delayed |
413 | // notification and iterators into past IR state cannot be represented. |
414 | listener->notifyBlockInserted(block, /*previous=*/region, |
415 | /*previousIt=*/{}); |
416 | } |
417 | } |
418 | |
419 | void rollback() override { |
420 | // Move the block back to its original position. |
421 | Region::iterator before = |
422 | insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end(); |
423 | region->getBlocks().splice(where: before, L2&: block->getParent()->getBlocks(), N: block); |
424 | } |
425 | |
426 | private: |
427 | // The region in which this block was previously contained. |
428 | Region *region; |
429 | |
430 | // The original successor of this block before it was moved. "nullptr" if |
431 | // this block was the only block in the region. |
432 | Block *insertBeforeBlock; |
433 | }; |
434 | |
435 | /// This structure contains the information pertaining to an argument that has |
436 | /// been converted. |
437 | struct ConvertedArgInfo { |
438 | ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize, |
439 | Value castValue = nullptr) |
440 | : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {} |
441 | |
442 | /// The start index of in the new argument list that contains arguments that |
443 | /// replace the original. |
444 | unsigned newArgIdx; |
445 | |
446 | /// The number of arguments that replaced the original argument. |
447 | unsigned newArgSize; |
448 | |
449 | /// The cast value that was created to cast from the new arguments to the |
450 | /// old. This only used if 'newArgSize' > 1. |
451 | Value castValue; |
452 | }; |
453 | |
454 | /// Block type conversion. This rewrite is partially reflected in the IR. |
455 | class BlockTypeConversionRewrite : public BlockRewrite { |
456 | public: |
457 | BlockTypeConversionRewrite( |
458 | ConversionPatternRewriterImpl &rewriterImpl, Block *block, |
459 | Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo, |
460 | const TypeConverter *converter) |
461 | : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block), |
462 | origBlock(origBlock), argInfo(argInfo), converter(converter) {} |
463 | |
464 | static bool classof(const IRRewrite *rewrite) { |
465 | return rewrite->getKind() == Kind::BlockTypeConversion; |
466 | } |
467 | |
468 | /// Materialize any necessary conversions for converted arguments that have |
469 | /// live users, using the provided `findLiveUser` to search for a user that |
470 | /// survives the conversion process. |
471 | LogicalResult |
472 | materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser); |
473 | |
474 | void commit(RewriterBase &rewriter) override; |
475 | |
476 | void rollback() override; |
477 | |
478 | private: |
479 | /// The original block that was requested to have its signature converted. |
480 | Block *origBlock; |
481 | |
482 | /// The conversion information for each of the arguments. The information is |
483 | /// std::nullopt if the argument was dropped during conversion. |
484 | SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo; |
485 | |
486 | /// The type converter used to convert the arguments. |
487 | const TypeConverter *converter; |
488 | }; |
489 | |
490 | /// Replacing a block argument. This rewrite is not immediately reflected in the |
491 | /// IR. An internal IR mapping is updated, but the actual replacement is delayed |
492 | /// until the rewrite is committed. |
493 | class ReplaceBlockArgRewrite : public BlockRewrite { |
494 | public: |
495 | ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl, |
496 | Block *block, BlockArgument arg) |
497 | : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {} |
498 | |
499 | static bool classof(const IRRewrite *rewrite) { |
500 | return rewrite->getKind() == Kind::ReplaceBlockArg; |
501 | } |
502 | |
503 | void commit(RewriterBase &rewriter) override; |
504 | |
505 | void rollback() override; |
506 | |
507 | private: |
508 | BlockArgument arg; |
509 | }; |
510 | |
511 | /// An operation rewrite. |
512 | class OperationRewrite : public IRRewrite { |
513 | public: |
514 | /// Return the operation that this rewrite operates on. |
515 | Operation *getOperation() const { return op; } |
516 | |
517 | static bool classof(const IRRewrite *rewrite) { |
518 | return rewrite->getKind() >= Kind::MoveOperation && |
519 | rewrite->getKind() <= Kind::UnresolvedMaterialization; |
520 | } |
521 | |
522 | protected: |
523 | OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl, |
524 | Operation *op) |
525 | : IRRewrite(kind, rewriterImpl), op(op) {} |
526 | |
527 | // The operation that this rewrite operates on. |
528 | Operation *op; |
529 | }; |
530 | |
531 | /// Moving of an operation. This rewrite is immediately reflected in the IR. |
532 | class MoveOperationRewrite : public OperationRewrite { |
533 | public: |
534 | MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, |
535 | Operation *op, Block *block, Operation *insertBeforeOp) |
536 | : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block), |
537 | insertBeforeOp(insertBeforeOp) {} |
538 | |
539 | static bool classof(const IRRewrite *rewrite) { |
540 | return rewrite->getKind() == Kind::MoveOperation; |
541 | } |
542 | |
543 | void commit(RewriterBase &rewriter) override { |
544 | // The operation was already moved. Just inform the listener. |
545 | if (auto *listener = rewriter.getListener()) { |
546 | // Note: `previousIt` cannot be passed because this is a delayed |
547 | // notification and iterators into past IR state cannot be represented. |
548 | listener->notifyOperationInserted( |
549 | op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block, |
550 | /*insertPt=*/{})); |
551 | } |
552 | } |
553 | |
554 | void rollback() override { |
555 | // Move the operation back to its original position. |
556 | Block::iterator before = |
557 | insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end(); |
558 | block->getOperations().splice(where: before, L2&: op->getBlock()->getOperations(), N: op); |
559 | } |
560 | |
561 | private: |
562 | // The block in which this operation was previously contained. |
563 | Block *block; |
564 | |
565 | // The original successor of this operation before it was moved. "nullptr" |
566 | // if this operation was the only operation in the region. |
567 | Operation *insertBeforeOp; |
568 | }; |
569 | |
570 | /// In-place modification of an op. This rewrite is immediately reflected in |
571 | /// the IR. The previous state of the operation is stored in this object. |
572 | class ModifyOperationRewrite : public OperationRewrite { |
573 | public: |
574 | ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, |
575 | Operation *op) |
576 | : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op), |
577 | name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()), |
578 | operands(op->operand_begin(), op->operand_end()), |
579 | successors(op->successor_begin(), op->successor_end()) { |
580 | if (OpaqueProperties prop = op->getPropertiesStorage()) { |
581 | // Make a copy of the properties. |
582 | propertiesStorage = operator new(op->getPropertiesStorageSize()); |
583 | OpaqueProperties propCopy(propertiesStorage); |
584 | name.initOpProperties(storage: propCopy, /*init=*/prop); |
585 | } |
586 | } |
587 | |
588 | static bool classof(const IRRewrite *rewrite) { |
589 | return rewrite->getKind() == Kind::ModifyOperation; |
590 | } |
591 | |
592 | ~ModifyOperationRewrite() override { |
593 | assert(!propertiesStorage && |
594 | "rewrite was neither committed nor rolled back" ); |
595 | } |
596 | |
597 | void commit(RewriterBase &rewriter) override { |
598 | // Notify the listener that the operation was modified in-place. |
599 | if (auto *listener = |
600 | dyn_cast_or_null<RewriterBase::Listener>(Val: rewriter.getListener())) |
601 | listener->notifyOperationModified(op); |
602 | |
603 | if (propertiesStorage) { |
604 | OpaqueProperties propCopy(propertiesStorage); |
605 | // Note: The operation may have been erased in the mean time, so |
606 | // OperationName must be stored in this object. |
607 | name.destroyOpProperties(properties: propCopy); |
608 | operator delete(propertiesStorage); |
609 | propertiesStorage = nullptr; |
610 | } |
611 | } |
612 | |
613 | void rollback() override { |
614 | op->setLoc(loc); |
615 | op->setAttrs(attrs); |
616 | op->setOperands(operands); |
617 | for (const auto &it : llvm::enumerate(First&: successors)) |
618 | op->setSuccessor(block: it.value(), index: it.index()); |
619 | if (propertiesStorage) { |
620 | OpaqueProperties propCopy(propertiesStorage); |
621 | op->copyProperties(rhs: propCopy); |
622 | name.destroyOpProperties(properties: propCopy); |
623 | operator delete(propertiesStorage); |
624 | propertiesStorage = nullptr; |
625 | } |
626 | } |
627 | |
628 | private: |
629 | OperationName name; |
630 | LocationAttr loc; |
631 | DictionaryAttr attrs; |
632 | SmallVector<Value, 8> operands; |
633 | SmallVector<Block *, 2> successors; |
634 | void *propertiesStorage = nullptr; |
635 | }; |
636 | |
637 | /// Replacing an operation. Erasing an operation is treated as a special case |
638 | /// with "null" replacements. This rewrite is not immediately reflected in the |
639 | /// IR. An internal IR mapping is updated, but values are not replaced and the |
640 | /// original op is not erased until the rewrite is committed. |
641 | class ReplaceOperationRewrite : public OperationRewrite { |
642 | public: |
643 | ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, |
644 | Operation *op, const TypeConverter *converter, |
645 | bool changedResults) |
646 | : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op), |
647 | converter(converter), changedResults(changedResults) {} |
648 | |
649 | static bool classof(const IRRewrite *rewrite) { |
650 | return rewrite->getKind() == Kind::ReplaceOperation; |
651 | } |
652 | |
653 | void commit(RewriterBase &rewriter) override; |
654 | |
655 | void rollback() override; |
656 | |
657 | void cleanup(RewriterBase &rewriter) override; |
658 | |
659 | const TypeConverter *getConverter() const { return converter; } |
660 | |
661 | bool hasChangedResults() const { return changedResults; } |
662 | |
663 | private: |
664 | /// An optional type converter that can be used to materialize conversions |
665 | /// between the new and old values if necessary. |
666 | const TypeConverter *converter; |
667 | |
668 | /// A boolean flag that indicates whether result types have changed or not. |
669 | bool changedResults; |
670 | }; |
671 | |
672 | class CreateOperationRewrite : public OperationRewrite { |
673 | public: |
674 | CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, |
675 | Operation *op) |
676 | : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {} |
677 | |
678 | static bool classof(const IRRewrite *rewrite) { |
679 | return rewrite->getKind() == Kind::CreateOperation; |
680 | } |
681 | |
682 | void commit(RewriterBase &rewriter) override { |
683 | // The operation was already created and inserted. Just inform the listener. |
684 | if (auto *listener = rewriter.getListener()) |
685 | listener->notifyOperationInserted(op, /*previous=*/{}); |
686 | } |
687 | |
688 | void rollback() override; |
689 | }; |
690 | |
/// Distinguishes the two flavors of materialization.
enum MaterializationKind {
  /// Materializes a conversion of an illegal block argument type to a legal
  /// one.
  Argument,

  /// Materializes a conversion from an illegal type to a legal one.
  Target
};
701 | |
702 | /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast" |
703 | /// op. Unresolved materializations are erased at the end of the dialect |
704 | /// conversion. |
705 | class UnresolvedMaterializationRewrite : public OperationRewrite { |
706 | public: |
707 | UnresolvedMaterializationRewrite( |
708 | ConversionPatternRewriterImpl &rewriterImpl, |
709 | UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr, |
710 | MaterializationKind kind = MaterializationKind::Target, |
711 | Type origOutputType = nullptr) |
712 | : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op), |
713 | converterAndKind(converter, kind), origOutputType(origOutputType) {} |
714 | |
715 | static bool classof(const IRRewrite *rewrite) { |
716 | return rewrite->getKind() == Kind::UnresolvedMaterialization; |
717 | } |
718 | |
719 | UnrealizedConversionCastOp getOperation() const { |
720 | return cast<UnrealizedConversionCastOp>(op); |
721 | } |
722 | |
723 | void rollback() override; |
724 | |
725 | void cleanup(RewriterBase &rewriter) override; |
726 | |
727 | /// Return the type converter of this materialization (which may be null). |
728 | const TypeConverter *getConverter() const { |
729 | return converterAndKind.getPointer(); |
730 | } |
731 | |
732 | /// Return the kind of this materialization. |
733 | MaterializationKind getMaterializationKind() const { |
734 | return converterAndKind.getInt(); |
735 | } |
736 | |
737 | /// Set the kind of this materialization. |
738 | void setMaterializationKind(MaterializationKind kind) { |
739 | converterAndKind.setInt(kind); |
740 | } |
741 | |
742 | /// Return the original illegal output type of the input values. |
743 | Type getOrigOutputType() const { return origOutputType; } |
744 | |
745 | private: |
746 | /// The corresponding type converter to use when resolving this |
747 | /// materialization, and the kind of this materialization. |
748 | llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind> |
749 | converterAndKind; |
750 | |
751 | /// The original output type. This is only used for argument conversions. |
752 | Type origOutputType; |
753 | }; |
754 | } // namespace |
755 | |
756 | /// Return "true" if there is an operation rewrite that matches the specified |
757 | /// rewrite type and operation among the given rewrites. |
758 | template <typename RewriteTy, typename R> |
759 | static bool hasRewrite(R &&rewrites, Operation *op) { |
760 | return any_of(std::forward<R>(rewrites), [&](auto &rewrite) { |
761 | auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get()); |
762 | return rewriteTy && rewriteTy->getOperation() == op; |
763 | }); |
764 | } |
765 | |
766 | /// Find the single rewrite object of the specified type and block among the |
767 | /// given rewrites. In debug mode, asserts that there is mo more than one such |
768 | /// object. Return "nullptr" if no object was found. |
769 | template <typename RewriteTy, typename R> |
770 | static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) { |
771 | RewriteTy *result = nullptr; |
772 | for (auto &rewrite : rewrites) { |
773 | auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get()); |
774 | if (rewriteTy && rewriteTy->getBlock() == block) { |
775 | #ifndef NDEBUG |
776 | assert(!result && "expected single matching rewrite" ); |
777 | result = rewriteTy; |
778 | #else |
779 | return rewriteTy; |
780 | #endif // NDEBUG |
781 | } |
782 | } |
783 | return result; |
784 | } |
785 | |
786 | //===----------------------------------------------------------------------===// |
787 | // ConversionPatternRewriterImpl |
788 | //===----------------------------------------------------------------------===// |
namespace mlir {
namespace detail {
/// Internal implementation of ConversionPatternRewriter. It holds the log of
/// IR rewrites, the value mapping and the ignored/replaced op sets used during
/// a dialect conversion. It is also a RewriterBase::Listener so that it is
/// notified of IR changes (see "Rewriter Notification Hooks" below).
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
  explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
                                         const ConversionConfig &config)
      : context(ctx), config(config) {}

  //===--------------------------------------------------------------------===//
  // State Management
  //===--------------------------------------------------------------------===//

  /// Return the current state of the rewriter: the number of recorded
  /// rewrites, ignored ops and replaced ops. `resetState` can later roll back
  /// to this state.
  RewriterState getCurrentState();

  /// Apply all requested operation rewrites. This method is invoked when the
  /// conversion process succeeds. All rewrites are committed first, then all
  /// of them are cleaned up.
  void applyRewrites();

  /// Reset the state of the rewriter to a previously saved point (as returned
  /// by `getCurrentState`). This undoes later rewrites and forgets ops that
  /// were marked ignored/replaced since then.
  void resetState(RewriterState state);

  /// Append a rewrite. Rewrites are committed upon success and rolled back upon
  /// failure. The rewrite is constructed from a reference to this object
  /// followed by `args`.
  template <typename RewriteTy, typename... Args>
  void appendRewrite(Args &&...args) {
    rewrites.push_back(
        std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
  }

  /// Undo the rewrites one by one in reverse order until "numRewritesToKeep"
  /// rewrites remain. The undone rewrites are discarded afterwards.
  void undoRewrites(unsigned numRewritesToKeep = 0);

  /// Remap the given values to those with potentially different types. Returns
  /// success if the values could be remapped, failure otherwise. `valueDiagTag`
  /// is the tag used when describing a value within a diagnostic, e.g.
  /// "operand". If `inputLoc` is set, it is used instead of each value's own
  /// location for diagnostics and materializations.
  LogicalResult remapValues(StringRef valueDiagTag,
                            std::optional<Location> inputLoc,
                            PatternRewriter &rewriter, ValueRange values,
                            SmallVectorImpl<Value> &remapped);

  /// Return "true" if the given operation is ignored, and does not need to be
  /// converted. Operations that were replaced or erased count as ignored too.
  bool isOpIgnored(Operation *op) const;

  /// Return "true" if the given operation was replaced or erased.
  bool wasOpReplaced(Operation *op) const;

  //===--------------------------------------------------------------------===//
  // Type Conversion
  //===--------------------------------------------------------------------===//

  /// Attempt to convert the signature of the given block, if successful a new
  /// block is returned containing the new arguments. Returns `block` if it did
  /// not require conversion. If `conversion` is null, the signature conversion
  /// is computed by `converter`; this fails if `converter` is null or cannot
  /// convert the block.
  FailureOr<Block *> convertBlockSignature(
      ConversionPatternRewriter &rewriter, Block *block,
      const TypeConverter *converter,
      TypeConverter::SignatureConversion *conversion = nullptr);

  /// Convert the types of non-entry block arguments within the given region.
  /// If `blockConversions` is non-empty, it must hold one signature conversion
  /// per non-entry block (in block order).
  LogicalResult convertNonEntryRegionTypes(
      ConversionPatternRewriter &rewriter, Region *region,
      const TypeConverter &converter,
      ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});

  /// Apply a signature conversion on the entry block of the given region,
  /// using `converter` for materializations if not null. Returns nullptr if
  /// the region is empty.
  Block *
  applySignatureConversion(ConversionPatternRewriter &rewriter, Region *region,
                           TypeConverter::SignatureConversion &conversion,
                           const TypeConverter *converter);

  /// Convert the types of block arguments within the given region, and record
  /// `converter` as the converter of `region`. Non-entry blocks are converted
  /// first. The entry block then uses `entryConversion` if given. Returns
  /// nullptr for an empty region.
  FailureOr<Block *>
  convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
                     const TypeConverter &converter,
                     TypeConverter::SignatureConversion *entryConversion);

  /// Apply the given signature conversion on the given block. The new block
  /// containing the updated signature is returned. If no conversions were
  /// necessary, e.g. if the block has no arguments, `block` is returned.
  /// `converter` is used to generate any necessary cast operations that
  /// translate between the origin argument types and those specified in the
  /// signature conversion.
  Block *applySignatureConversion(
      ConversionPatternRewriter &rewriter, Block *block,
      const TypeConverter *converter,
      TypeConverter::SignatureConversion &signatureConversion);

  //===--------------------------------------------------------------------===//
  // Materializations
  //===--------------------------------------------------------------------===//
  /// Build an unresolved materialization operation given an output type and set
  /// of input operands. If there is a single input that already has
  /// `outputType`, that input is returned instead.
  Value buildUnresolvedMaterialization(MaterializationKind kind,
                                       Block *insertBlock,
                                       Block::iterator insertPt, Location loc,
                                       ValueRange inputs, Type outputType,
                                       Type origOutputType,
                                       const TypeConverter *converter);

  /// Build an unresolved argument materialization at the start of `block`.
  /// Note that, unlike buildUnresolvedMaterialization, `origOutputType` comes
  /// before `outputType` here.
  Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
                                               ValueRange inputs,
                                               Type origOutputType,
                                               Type outputType,
                                               const TypeConverter *converter);

  /// Build an unresolved target materialization converting `input` to
  /// `outputType`. It is inserted right after the op that defines `input`, or
  /// at the start of its block if `input` is a block argument.
  Value buildUnresolvedTargetMaterialization(Location loc, Value input,
                                             Type outputType,
                                             const TypeConverter *converter);

  //===--------------------------------------------------------------------===//
  // Rewriter Notification Hooks
  //===--------------------------------------------------------------------===//

  /// Notifies that an op was inserted.
  void notifyOperationInserted(Operation *op,
                               OpBuilder::InsertPoint previous) override;

  /// Notifies that an op is about to be replaced with the given values.
  void notifyOpReplaced(Operation *op, ValueRange newValues);

  /// Notifies that a block is about to be erased.
  void notifyBlockIsBeingErased(Block *block);

  /// Notifies that a block was inserted.
  void notifyBlockInserted(Block *block, Region *previous,
                           Region::iterator previousIt) override;

  /// Notifies that a block is being inlined into another block.
  void notifyBlockBeingInlined(Block *block, Block *srcBlock,
                               Block::iterator before);

  /// Notifies that a pattern match failed for the given reason.
  void
  notifyMatchFailure(Location loc,
                     function_ref<void(Diagnostic &)> reasonCallback) override;

  //===--------------------------------------------------------------------===//
  // IR Erasure
  //===--------------------------------------------------------------------===//

  /// A rewriter that keeps track of erased ops and blocks. It ensures that no
  /// operation or block is erased multiple times. This rewriter assumes that
  /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
  /// NOTE(review): erased IR is tracked by raw pointer, so presumably the
  /// assumption exists because new IR could reuse a freed address and be
  /// mistaken for already-erased IR.
  struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener {
  public:
    /// The rewriter is its own listener, so that it receives the erase
    /// notifications that populate `erased`.
    SingleEraseRewriter(MLIRContext *context)
        : RewriterBase(context, /*listener=*/this) {}

    /// Erase the given op (unless it was already erased). All uses of its
    /// results are dropped first.
    void eraseOp(Operation *op) override {
      if (erased.contains(V: op))
        return;
      op->dropAllUses();
      RewriterBase::eraseOp(op);
    }

    /// Erase the given block (unless it was already erased). The block must
    /// already be empty. Uses of its arguments are dropped first.
    void eraseBlock(Block *block) override {
      if (erased.contains(V: block))
        return;
      assert(block->empty() && "expected empty block" );
      block->dropAllDefinedValueUses();
      RewriterBase::eraseBlock(block);
    }

    /// Record every erased op so that it is not erased a second time.
    void notifyOperationErased(Operation *op) override { erased.insert(V: op); }

    /// Record every erased block so that it is not erased a second time.
    void notifyBlockErased(Block *block) override { erased.insert(V: block); }

    /// Pointers to all erased operations and blocks.
    DenseSet<void *> erased;
  };

  //===--------------------------------------------------------------------===//
  // State
  //===--------------------------------------------------------------------===//

  /// MLIR context.
  MLIRContext *context;

  /// Mapping between replaced values that differ in type. This happens when
  /// replacing a value with one of a different type.
  ConversionValueMapping mapping;

  /// Ordered list of IR rewrites (e.g. op creations and replacements, block
  /// type conversions, block inlining, unresolved materializations). They are
  /// committed in order by `applyRewrites` and rolled back in reverse order by
  /// `undoRewrites`.
  SmallVector<std::unique_ptr<IRRewrite>> rewrites;

  /// A set of operations that should no longer be considered for legalization.
  /// E.g., ops that are recursively legal. Ops that were replaced/erased are
  /// tracked separately.
  SetVector<Operation *> ignoredOps;

  /// A set of operations that were replaced/erased. Such ops are not erased
  /// immediately but only when the dialect conversion succeeds. In the
  /// meantime, they should no longer be considered for legalization and any
  /// attempt to modify/access them is invalid rewriter API usage.
  SetVector<Operation *> replacedOps;

  /// The current type converter, or nullptr if no type converter is currently
  /// active.
  const TypeConverter *currentTypeConverter = nullptr;

  /// A mapping of regions to type converters that should be used when
  /// converting the arguments of blocks within that region.
  DenseMap<Region *, const TypeConverter *> regionToConverter;

  /// Dialect conversion configuration. Held by reference, so the referenced
  /// configuration must outlive this object.
  const ConversionConfig &config;

#ifndef NDEBUG
  /// A set of operations that have pending updates. This tracking isn't
  /// strictly necessary, and is thus only active during debug builds for extra
  /// verification.
  SmallPtrSet<Operation *, 1> pendingRootUpdates;

  /// A logger used to emit diagnostics during the conversion process.
  llvm::ScopedPrinter logger{llvm::dbgs()};
#endif
};
} // namespace detail
} // namespace mlir
1014 | |
1015 | const ConversionConfig &IRRewrite::getConfig() const { |
1016 | return rewriterImpl.config; |
1017 | } |
1018 | |
1019 | void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) { |
1020 | // Inform the listener about all IR modifications that have already taken |
1021 | // place: References to the original block have been replaced with the new |
1022 | // block. |
1023 | if (auto *listener = |
1024 | dyn_cast_or_null<RewriterBase::Listener>(Val: rewriter.getListener())) |
1025 | for (Operation *op : block->getUsers()) |
1026 | listener->notifyOperationModified(op); |
1027 | |
1028 | // Process the remapping for each of the original arguments. |
1029 | for (auto [origArg, info] : |
1030 | llvm::zip_equal(t: origBlock->getArguments(), u&: argInfo)) { |
1031 | // Handle the case of a 1->0 value mapping. |
1032 | if (!info) { |
1033 | if (Value newArg = |
1034 | rewriterImpl.mapping.lookupOrNull(from: origArg, desiredType: origArg.getType())) |
1035 | rewriter.replaceAllUsesWith(from: origArg, to: newArg); |
1036 | continue; |
1037 | } |
1038 | |
1039 | // Otherwise this is a 1->1+ value mapping. |
1040 | Value castValue = info->castValue; |
1041 | assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping" ); |
1042 | |
1043 | // If the argument is still used, replace it with the generated cast. |
1044 | if (!origArg.use_empty()) { |
1045 | rewriter.replaceAllUsesWith(from: origArg, to: rewriterImpl.mapping.lookupOrDefault( |
1046 | from: castValue, desiredType: origArg.getType())); |
1047 | } |
1048 | } |
1049 | } |
1050 | |
1051 | void BlockTypeConversionRewrite::rollback() { |
1052 | block->replaceAllUsesWith(newValue&: origBlock); |
1053 | } |
1054 | |
1055 | LogicalResult BlockTypeConversionRewrite::materializeLiveConversions( |
1056 | function_ref<Operation *(Value)> findLiveUser) { |
1057 | // Process the remapping for each of the original arguments. |
1058 | for (auto it : llvm::enumerate(First: origBlock->getArguments())) { |
1059 | BlockArgument origArg = it.value(); |
1060 | // Note: `block` may be detached, so OpBuilder::atBlockBegin cannot be used. |
1061 | OpBuilder builder(it.value().getContext(), /*listener=*/&rewriterImpl); |
1062 | builder.setInsertionPointToStart(block); |
1063 | |
1064 | // If the type of this argument changed and the argument is still live, we |
1065 | // need to materialize a conversion. |
1066 | if (rewriterImpl.mapping.lookupOrNull(from: origArg, desiredType: origArg.getType())) |
1067 | continue; |
1068 | Operation *liveUser = findLiveUser(origArg); |
1069 | if (!liveUser) |
1070 | continue; |
1071 | |
1072 | Value replacementValue = rewriterImpl.mapping.lookupOrDefault(from: origArg); |
1073 | bool isDroppedArg = replacementValue == origArg; |
1074 | if (!isDroppedArg) |
1075 | builder.setInsertionPointAfterValue(replacementValue); |
1076 | Value newArg; |
1077 | if (converter) { |
1078 | newArg = converter->materializeSourceConversion( |
1079 | builder, loc: origArg.getLoc(), resultType: origArg.getType(), |
1080 | inputs: isDroppedArg ? ValueRange() : ValueRange(replacementValue)); |
1081 | assert((!newArg || newArg.getType() == origArg.getType()) && |
1082 | "materialization hook did not provide a value of the expected " |
1083 | "type" ); |
1084 | } |
1085 | if (!newArg) { |
1086 | InFlightDiagnostic diag = |
1087 | emitError(loc: origArg.getLoc()) |
1088 | << "failed to materialize conversion for block argument #" |
1089 | << it.index() << " that remained live after conversion, type was " |
1090 | << origArg.getType(); |
1091 | if (!isDroppedArg) |
1092 | diag << ", with target type " << replacementValue.getType(); |
1093 | diag.attachNote(noteLoc: liveUser->getLoc()) |
1094 | << "see existing live user here: " << *liveUser; |
1095 | return failure(); |
1096 | } |
1097 | rewriterImpl.mapping.map(oldVal: origArg, newVal: newArg); |
1098 | } |
1099 | return success(); |
1100 | } |
1101 | |
1102 | void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { |
1103 | Value repl = rewriterImpl.mapping.lookupOrNull(from: arg, desiredType: arg.getType()); |
1104 | if (!repl) |
1105 | return; |
1106 | |
1107 | if (isa<BlockArgument>(Val: repl)) { |
1108 | rewriter.replaceAllUsesWith(from: arg, to: repl); |
1109 | return; |
1110 | } |
1111 | |
1112 | // If the replacement value is an operation, we check to make sure that we |
1113 | // don't replace uses that are within the parent operation of the |
1114 | // replacement value. |
1115 | Operation *replOp = cast<OpResult>(Val&: repl).getOwner(); |
1116 | Block *replBlock = replOp->getBlock(); |
1117 | rewriter.replaceUsesWithIf(from: arg, to: repl, functor: [&](OpOperand &operand) { |
1118 | Operation *user = operand.getOwner(); |
1119 | return user->getBlock() != replBlock || replOp->isBeforeInBlock(other: user); |
1120 | }); |
1121 | } |
1122 | |
1123 | void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(value: arg); } |
1124 | |
1125 | void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { |
1126 | auto *listener = |
1127 | dyn_cast_or_null<RewriterBase::Listener>(Val: rewriter.getListener()); |
1128 | |
1129 | // Compute replacement values. |
1130 | SmallVector<Value> replacements = |
1131 | llvm::map_to_vector(C: op->getResults(), F: [&](OpResult result) { |
1132 | return rewriterImpl.mapping.lookupOrNull(from: result, desiredType: result.getType()); |
1133 | }); |
1134 | |
1135 | // Notify the listener that the operation is about to be replaced. |
1136 | if (listener) |
1137 | listener->notifyOperationReplaced(op, replacement: replacements); |
1138 | |
1139 | // Replace all uses with the new values. |
1140 | for (auto [result, newValue] : |
1141 | llvm::zip_equal(t: op->getResults(), u&: replacements)) |
1142 | if (newValue) |
1143 | rewriter.replaceAllUsesWith(from: result, to: newValue); |
1144 | |
1145 | // The original op will be erased, so remove it from the set of unlegalized |
1146 | // ops. |
1147 | if (getConfig().unlegalizedOps) |
1148 | getConfig().unlegalizedOps->erase(V: op); |
1149 | |
1150 | // Notify the listener that the operation (and its nested operations) was |
1151 | // erased. |
1152 | if (listener) { |
1153 | op->walk<WalkOrder::PostOrder>( |
1154 | callback: [&](Operation *op) { listener->notifyOperationErased(op); }); |
1155 | } |
1156 | |
1157 | // Do not erase the operation yet. It may still be referenced in `mapping`. |
1158 | // Just unlink it for now and erase it during cleanup. |
1159 | op->getBlock()->getOperations().remove(IT: op); |
1160 | } |
1161 | |
1162 | void ReplaceOperationRewrite::rollback() { |
1163 | for (auto result : op->getResults()) |
1164 | rewriterImpl.mapping.erase(value: result); |
1165 | } |
1166 | |
1167 | void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) { |
1168 | rewriter.eraseOp(op); |
1169 | } |
1170 | |
1171 | void CreateOperationRewrite::rollback() { |
1172 | for (Region ®ion : op->getRegions()) { |
1173 | while (!region.getBlocks().empty()) |
1174 | region.getBlocks().remove(IT: region.getBlocks().begin()); |
1175 | } |
1176 | op->dropAllUses(); |
1177 | op->erase(); |
1178 | } |
1179 | |
1180 | void UnresolvedMaterializationRewrite::rollback() { |
1181 | if (getMaterializationKind() == MaterializationKind::Target) { |
1182 | for (Value input : op->getOperands()) |
1183 | rewriterImpl.mapping.erase(value: input); |
1184 | } |
1185 | op->erase(); |
1186 | } |
1187 | |
1188 | void UnresolvedMaterializationRewrite::cleanup(RewriterBase &rewriter) { |
1189 | rewriter.eraseOp(op); |
1190 | } |
1191 | |
1192 | void ConversionPatternRewriterImpl::applyRewrites() { |
1193 | // Commit all rewrites. |
1194 | IRRewriter rewriter(context, config.listener); |
1195 | for (auto &rewrite : rewrites) |
1196 | rewrite->commit(rewriter); |
1197 | |
1198 | // Clean up all rewrites. |
1199 | SingleEraseRewriter eraseRewriter(context); |
1200 | for (auto &rewrite : rewrites) |
1201 | rewrite->cleanup(rewriter&: eraseRewriter); |
1202 | } |
1203 | |
1204 | //===----------------------------------------------------------------------===// |
1205 | // State Management |
1206 | |
1207 | RewriterState ConversionPatternRewriterImpl::getCurrentState() { |
1208 | return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size()); |
1209 | } |
1210 | |
1211 | void ConversionPatternRewriterImpl::resetState(RewriterState state) { |
1212 | // Undo any rewrites. |
1213 | undoRewrites(numRewritesToKeep: state.numRewrites); |
1214 | |
1215 | // Pop all of the recorded ignored operations that are no longer valid. |
1216 | while (ignoredOps.size() != state.numIgnoredOperations) |
1217 | ignoredOps.pop_back(); |
1218 | |
1219 | while (replacedOps.size() != state.numReplacedOps) |
1220 | replacedOps.pop_back(); |
1221 | } |
1222 | |
1223 | void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) { |
1224 | for (auto &rewrite : |
1225 | llvm::reverse(C: llvm::drop_begin(RangeOrContainer&: rewrites, N: numRewritesToKeep))) |
1226 | rewrite->rollback(); |
1227 | rewrites.resize(N: numRewritesToKeep); |
1228 | } |
1229 | |
1230 | LogicalResult ConversionPatternRewriterImpl::remapValues( |
1231 | StringRef valueDiagTag, std::optional<Location> inputLoc, |
1232 | PatternRewriter &rewriter, ValueRange values, |
1233 | SmallVectorImpl<Value> &remapped) { |
1234 | remapped.reserve(N: llvm::size(Range&: values)); |
1235 | |
1236 | SmallVector<Type, 1> legalTypes; |
1237 | for (const auto &it : llvm::enumerate(First&: values)) { |
1238 | Value operand = it.value(); |
1239 | Type origType = operand.getType(); |
1240 | |
1241 | // If a converter was provided, get the desired legal types for this |
1242 | // operand. |
1243 | Type desiredType; |
1244 | if (currentTypeConverter) { |
1245 | // If there is no legal conversion, fail to match this pattern. |
1246 | legalTypes.clear(); |
1247 | if (failed(result: currentTypeConverter->convertType(t: origType, results&: legalTypes))) { |
1248 | Location operandLoc = inputLoc ? *inputLoc : operand.getLoc(); |
1249 | notifyMatchFailure(loc: operandLoc, reasonCallback: [=](Diagnostic &diag) { |
1250 | diag << "unable to convert type for " << valueDiagTag << " #" |
1251 | << it.index() << ", type was " << origType; |
1252 | }); |
1253 | return failure(); |
1254 | } |
1255 | // TODO: There currently isn't any mechanism to do 1->N type conversion |
1256 | // via the PatternRewriter replacement API, so for now we just ignore it. |
1257 | if (legalTypes.size() == 1) |
1258 | desiredType = legalTypes.front(); |
1259 | } else { |
1260 | // TODO: What we should do here is just set `desiredType` to `origType` |
1261 | // and then handle the necessary type conversions after the conversion |
1262 | // process has finished. Unfortunately a lot of patterns currently rely on |
1263 | // receiving the new operands even if the types change, so we keep the |
1264 | // original behavior here for now until all of the patterns relying on |
1265 | // this get updated. |
1266 | } |
1267 | Value newOperand = mapping.lookupOrDefault(from: operand, desiredType); |
1268 | |
1269 | // Handle the case where the conversion was 1->1 and the new operand type |
1270 | // isn't legal. |
1271 | Type newOperandType = newOperand.getType(); |
1272 | if (currentTypeConverter && desiredType && newOperandType != desiredType) { |
1273 | Location operandLoc = inputLoc ? *inputLoc : operand.getLoc(); |
1274 | Value castValue = buildUnresolvedTargetMaterialization( |
1275 | loc: operandLoc, input: newOperand, outputType: desiredType, converter: currentTypeConverter); |
1276 | mapping.map(oldVal: mapping.lookupOrDefault(from: newOperand), newVal: castValue); |
1277 | newOperand = castValue; |
1278 | } |
1279 | remapped.push_back(Elt: newOperand); |
1280 | } |
1281 | return success(); |
1282 | } |
1283 | |
1284 | bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const { |
1285 | // Check to see if this operation is ignored or was replaced. |
1286 | return replacedOps.count(key: op) || ignoredOps.count(key: op); |
1287 | } |
1288 | |
1289 | bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const { |
1290 | // Check to see if this operation was replaced. |
1291 | return replacedOps.count(key: op); |
1292 | } |
1293 | |
1294 | //===----------------------------------------------------------------------===// |
1295 | // Type Conversion |
1296 | |
1297 | FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature( |
1298 | ConversionPatternRewriter &rewriter, Block *block, |
1299 | const TypeConverter *converter, |
1300 | TypeConverter::SignatureConversion *conversion) { |
1301 | if (conversion) |
1302 | return applySignatureConversion(rewriter, block, converter, signatureConversion&: *conversion); |
1303 | |
1304 | // If a converter wasn't provided, and the block wasn't already converted, |
1305 | // there is nothing we can do. |
1306 | if (!converter) |
1307 | return failure(); |
1308 | |
1309 | // Try to convert the signature for the block with the provided converter. |
1310 | if (auto conversion = converter->convertBlockSignature(block)) |
1311 | return applySignatureConversion(rewriter, block, converter, signatureConversion&: *conversion); |
1312 | return failure(); |
1313 | } |
1314 | |
1315 | Block *ConversionPatternRewriterImpl::applySignatureConversion( |
1316 | ConversionPatternRewriter &rewriter, Region *region, |
1317 | TypeConverter::SignatureConversion &conversion, |
1318 | const TypeConverter *converter) { |
1319 | if (!region->empty()) |
1320 | return *convertBlockSignature(rewriter, block: ®ion->front(), converter, |
1321 | conversion: &conversion); |
1322 | return nullptr; |
1323 | } |
1324 | |
1325 | FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes( |
1326 | ConversionPatternRewriter &rewriter, Region *region, |
1327 | const TypeConverter &converter, |
1328 | TypeConverter::SignatureConversion *entryConversion) { |
1329 | regionToConverter[region] = &converter; |
1330 | if (region->empty()) |
1331 | return nullptr; |
1332 | |
1333 | if (failed(result: convertNonEntryRegionTypes(rewriter, region, converter))) |
1334 | return failure(); |
1335 | |
1336 | FailureOr<Block *> newEntry = convertBlockSignature( |
1337 | rewriter, block: ®ion->front(), converter: &converter, conversion: entryConversion); |
1338 | return newEntry; |
1339 | } |
1340 | |
1341 | LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes( |
1342 | ConversionPatternRewriter &rewriter, Region *region, |
1343 | const TypeConverter &converter, |
1344 | ArrayRef<TypeConverter::SignatureConversion> blockConversions) { |
1345 | regionToConverter[region] = &converter; |
1346 | if (region->empty()) |
1347 | return success(); |
1348 | |
1349 | // Convert the arguments of each block within the region. |
1350 | int blockIdx = 0; |
1351 | assert((blockConversions.empty() || |
1352 | blockConversions.size() == region->getBlocks().size() - 1) && |
1353 | "expected either to provide no SignatureConversions at all or to " |
1354 | "provide a SignatureConversion for each non-entry block" ); |
1355 | |
1356 | for (Block &block : |
1357 | llvm::make_early_inc_range(Range: llvm::drop_begin(RangeOrContainer&: *region, N: 1))) { |
1358 | TypeConverter::SignatureConversion *blockConversion = |
1359 | blockConversions.empty() |
1360 | ? nullptr |
1361 | : const_cast<TypeConverter::SignatureConversion *>( |
1362 | &blockConversions[blockIdx++]); |
1363 | |
1364 | if (failed(result: convertBlockSignature(rewriter, block: &block, converter: &converter, |
1365 | conversion: blockConversion))) |
1366 | return failure(); |
1367 | } |
1368 | return success(); |
1369 | } |
1370 | |
1371 | Block *ConversionPatternRewriterImpl::applySignatureConversion( |
1372 | ConversionPatternRewriter &rewriter, Block *block, |
1373 | const TypeConverter *converter, |
1374 | TypeConverter::SignatureConversion &signatureConversion) { |
1375 | OpBuilder::InsertionGuard g(rewriter); |
1376 | |
1377 | // If no arguments are being changed or added, there is nothing to do. |
1378 | unsigned origArgCount = block->getNumArguments(); |
1379 | auto convertedTypes = signatureConversion.getConvertedTypes(); |
1380 | if (llvm::equal(LRange: block->getArgumentTypes(), RRange&: convertedTypes)) |
1381 | return block; |
1382 | |
1383 | // Compute the locations of all block arguments in the new block. |
1384 | SmallVector<Location> newLocs(convertedTypes.size(), |
1385 | rewriter.getUnknownLoc()); |
1386 | for (unsigned i = 0; i < origArgCount; ++i) { |
1387 | auto inputMap = signatureConversion.getInputMapping(input: i); |
1388 | if (!inputMap || inputMap->replacementValue) |
1389 | continue; |
1390 | Location origLoc = block->getArgument(i).getLoc(); |
1391 | for (unsigned j = 0; j < inputMap->size; ++j) |
1392 | newLocs[inputMap->inputNo + j] = origLoc; |
1393 | } |
1394 | |
1395 | // Insert a new block with the converted block argument types and move all ops |
1396 | // from the old block to the new block. |
1397 | Block *newBlock = |
1398 | rewriter.createBlock(parent: block->getParent(), insertPt: std::next(x: block->getIterator()), |
1399 | argTypes: convertedTypes, locs: newLocs); |
1400 | |
1401 | // If a listener is attached to the dialect conversion, ops cannot be moved |
1402 | // to the destination block in bulk ("fast path"). This is because at the time |
1403 | // the notifications are sent, it is unknown which ops were moved. Instead, |
1404 | // ops should be moved one-by-one ("slow path"), so that a separate |
1405 | // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is |
1406 | // a bit more efficient, so we try to do that when possible. |
1407 | bool fastPath = !config.listener; |
1408 | if (fastPath) { |
1409 | appendRewrite<InlineBlockRewrite>(args&: newBlock, args&: block, args: newBlock->end()); |
1410 | newBlock->getOperations().splice(where: newBlock->end(), L2&: block->getOperations()); |
1411 | } else { |
1412 | while (!block->empty()) |
1413 | rewriter.moveOpBefore(op: &block->front(), block: newBlock, iterator: newBlock->end()); |
1414 | } |
1415 | |
1416 | // Replace all uses of the old block with the new block. |
1417 | block->replaceAllUsesWith(newValue&: newBlock); |
1418 | |
1419 | // Remap each of the original arguments as determined by the signature |
1420 | // conversion. |
1421 | SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo; |
1422 | argInfo.resize(N: origArgCount); |
1423 | |
1424 | for (unsigned i = 0; i != origArgCount; ++i) { |
1425 | auto inputMap = signatureConversion.getInputMapping(input: i); |
1426 | if (!inputMap) |
1427 | continue; |
1428 | BlockArgument origArg = block->getArgument(i); |
1429 | |
1430 | // If inputMap->replacementValue is not nullptr, then the argument is |
1431 | // dropped and a replacement value is provided to be the remappedValue. |
1432 | if (inputMap->replacementValue) { |
1433 | assert(inputMap->size == 0 && |
1434 | "invalid to provide a replacement value when the argument isn't " |
1435 | "dropped" ); |
1436 | mapping.map(oldVal: origArg, newVal: inputMap->replacementValue); |
1437 | appendRewrite<ReplaceBlockArgRewrite>(args&: block, args&: origArg); |
1438 | continue; |
1439 | } |
1440 | |
1441 | // Otherwise, this is a 1->1+ mapping. |
1442 | auto replArgs = |
1443 | newBlock->getArguments().slice(N: inputMap->inputNo, M: inputMap->size); |
1444 | Value newArg; |
1445 | |
1446 | // If this is a 1->1 mapping and the types of new and replacement arguments |
1447 | // match (i.e. it's an identity map), then the argument is mapped to its |
1448 | // original type. |
1449 | // FIXME: We simply pass through the replacement argument if there wasn't a |
1450 | // converter, which isn't great as it allows implicit type conversions to |
1451 | // appear. We should properly restructure this code to handle cases where a |
1452 | // converter isn't provided and also to properly handle the case where an |
1453 | // argument materialization is actually a temporary source materialization |
1454 | // (e.g. in the case of 1->N). |
1455 | if (replArgs.size() == 1 && |
1456 | (!converter || replArgs[0].getType() == origArg.getType())) { |
1457 | newArg = replArgs.front(); |
1458 | } else { |
1459 | Type origOutputType = origArg.getType(); |
1460 | |
1461 | // Legalize the argument output type. |
1462 | Type outputType = origOutputType; |
1463 | if (Type legalOutputType = converter->convertType(t: outputType)) |
1464 | outputType = legalOutputType; |
1465 | |
1466 | newArg = buildUnresolvedArgumentMaterialization( |
1467 | block: newBlock, loc: origArg.getLoc(), inputs: replArgs, origOutputType, outputType, |
1468 | converter); |
1469 | } |
1470 | |
1471 | mapping.map(oldVal: origArg, newVal: newArg); |
1472 | appendRewrite<ReplaceBlockArgRewrite>(args&: block, args&: origArg); |
1473 | argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg); |
1474 | } |
1475 | |
1476 | appendRewrite<BlockTypeConversionRewrite>(args&: newBlock, args&: block, args&: argInfo, |
1477 | args&: converter); |
1478 | |
1479 | // Erase the old block. (It is just unlinked for now and will be erased during |
1480 | // cleanup.) |
1481 | rewriter.eraseBlock(block); |
1482 | |
1483 | return newBlock; |
1484 | } |
1485 | |
1486 | //===----------------------------------------------------------------------===// |
1487 | // Materializations |
1488 | //===----------------------------------------------------------------------===// |
1489 | |
1490 | /// Build an unresolved materialization operation given an output type and set |
1491 | /// of input operands. |
1492 | Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization( |
1493 | MaterializationKind kind, Block *insertBlock, Block::iterator insertPt, |
1494 | Location loc, ValueRange inputs, Type outputType, Type origOutputType, |
1495 | const TypeConverter *converter) { |
1496 | // Avoid materializing an unnecessary cast. |
1497 | if (inputs.size() == 1 && inputs.front().getType() == outputType) |
1498 | return inputs.front(); |
1499 | |
1500 | // Create an unresolved materialization. We use a new OpBuilder to avoid |
1501 | // tracking the materialization like we do for other operations. |
1502 | OpBuilder builder(insertBlock, insertPt); |
1503 | auto convertOp = |
1504 | builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs); |
1505 | appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind, |
1506 | origOutputType); |
1507 | return convertOp.getResult(0); |
1508 | } |
1509 | Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization( |
1510 | Block *block, Location loc, ValueRange inputs, Type origOutputType, |
1511 | Type outputType, const TypeConverter *converter) { |
1512 | return buildUnresolvedMaterialization(kind: MaterializationKind::Argument, insertBlock: block, |
1513 | insertPt: block->begin(), loc, inputs, outputType, |
1514 | origOutputType, converter); |
1515 | } |
1516 | Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization( |
1517 | Location loc, Value input, Type outputType, |
1518 | const TypeConverter *converter) { |
1519 | Block *insertBlock = input.getParentBlock(); |
1520 | Block::iterator insertPt = insertBlock->begin(); |
1521 | if (OpResult inputRes = dyn_cast<OpResult>(Val&: input)) |
1522 | insertPt = ++inputRes.getOwner()->getIterator(); |
1523 | |
1524 | return buildUnresolvedMaterialization(kind: MaterializationKind::Target, |
1525 | insertBlock, insertPt, loc, inputs: input, |
1526 | outputType, origOutputType: outputType, converter); |
1527 | } |
1528 | |
1529 | //===----------------------------------------------------------------------===// |
1530 | // Rewriter Notification Hooks |
1531 | |
1532 | void ConversionPatternRewriterImpl::notifyOperationInserted( |
1533 | Operation *op, OpBuilder::InsertPoint previous) { |
1534 | LLVM_DEBUG({ |
1535 | logger.startLine() << "** Insert : '" << op->getName() << "'(" << op |
1536 | << ")\n" ; |
1537 | }); |
1538 | assert(!wasOpReplaced(op->getParentOp()) && |
1539 | "attempting to insert into a block within a replaced/erased op" ); |
1540 | |
1541 | if (!previous.isSet()) { |
1542 | // This is a newly created op. |
1543 | appendRewrite<CreateOperationRewrite>(args&: op); |
1544 | return; |
1545 | } |
1546 | Operation *prevOp = previous.getPoint() == previous.getBlock()->end() |
1547 | ? nullptr |
1548 | : &*previous.getPoint(); |
1549 | appendRewrite<MoveOperationRewrite>(args&: op, args: previous.getBlock(), args&: prevOp); |
1550 | } |
1551 | |
1552 | void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, |
1553 | ValueRange newValues) { |
1554 | assert(newValues.size() == op->getNumResults()); |
1555 | assert(!ignoredOps.contains(op) && "operation was already replaced" ); |
1556 | |
1557 | // Track if any of the results changed, e.g. erased and replaced with null. |
1558 | bool resultChanged = false; |
1559 | |
1560 | // Create mappings for each of the new result values. |
1561 | for (auto [newValue, result] : llvm::zip(t&: newValues, u: op->getResults())) { |
1562 | if (!newValue) { |
1563 | resultChanged = true; |
1564 | continue; |
1565 | } |
1566 | // Remap, and check for any result type changes. |
1567 | mapping.map(oldVal: result, newVal: newValue); |
1568 | resultChanged |= (newValue.getType() != result.getType()); |
1569 | } |
1570 | |
1571 | appendRewrite<ReplaceOperationRewrite>(args&: op, args&: currentTypeConverter, |
1572 | args&: resultChanged); |
1573 | |
1574 | // Mark this operation and all nested ops as replaced. |
1575 | op->walk(callback: [&](Operation *op) { replacedOps.insert(X: op); }); |
1576 | } |
1577 | |
1578 | void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) { |
1579 | Region *region = block->getParent(); |
1580 | Block *origNextBlock = block->getNextNode(); |
1581 | appendRewrite<EraseBlockRewrite>(args&: block, args&: region, args&: origNextBlock); |
1582 | } |
1583 | |
1584 | void ConversionPatternRewriterImpl::notifyBlockInserted( |
1585 | Block *block, Region *previous, Region::iterator previousIt) { |
1586 | assert(!wasOpReplaced(block->getParentOp()) && |
1587 | "attempting to insert into a region within a replaced/erased op" ); |
1588 | LLVM_DEBUG( |
1589 | { |
1590 | Operation *parent = block->getParentOp(); |
1591 | if (parent) { |
1592 | logger.startLine() << "** Insert Block into : '" << parent->getName() |
1593 | << "'(" << parent << ")\n" ; |
1594 | } else { |
1595 | logger.startLine() |
1596 | << "** Insert Block into detached Region (nullptr parent op)'" ; |
1597 | } |
1598 | }); |
1599 | |
1600 | if (!previous) { |
1601 | // This is a newly created block. |
1602 | appendRewrite<CreateBlockRewrite>(args&: block); |
1603 | return; |
1604 | } |
1605 | Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt; |
1606 | appendRewrite<MoveBlockRewrite>(args&: block, args&: previous, args&: prevBlock); |
1607 | } |
1608 | |
1609 | void ConversionPatternRewriterImpl::notifyBlockBeingInlined( |
1610 | Block *block, Block *srcBlock, Block::iterator before) { |
1611 | appendRewrite<InlineBlockRewrite>(args&: block, args&: srcBlock, args&: before); |
1612 | } |
1613 | |
1614 | void ConversionPatternRewriterImpl::notifyMatchFailure( |
1615 | Location loc, function_ref<void(Diagnostic &)> reasonCallback) { |
1616 | LLVM_DEBUG({ |
1617 | Diagnostic diag(loc, DiagnosticSeverity::Remark); |
1618 | reasonCallback(diag); |
1619 | logger.startLine() << "** Failure : " << diag.str() << "\n" ; |
1620 | if (config.notifyCallback) |
1621 | config.notifyCallback(diag); |
1622 | }); |
1623 | } |
1624 | |
1625 | //===----------------------------------------------------------------------===// |
1626 | // ConversionPatternRewriter |
1627 | //===----------------------------------------------------------------------===// |
1628 | |
1629 | ConversionPatternRewriter::ConversionPatternRewriter( |
1630 | MLIRContext *ctx, const ConversionConfig &config) |
1631 | : PatternRewriter(ctx), |
1632 | impl(new detail::ConversionPatternRewriterImpl(ctx, config)) { |
1633 | setListener(impl.get()); |
1634 | } |
1635 | |
/// Defaulted here in the .cpp, where ConversionPatternRewriterImpl is defined.
/// NOTE(review): presumably out of line so destroying the owned `impl` sees the
/// complete impl type; confirm against the header declaration.
ConversionPatternRewriter::~ConversionPatternRewriter() = default;
1637 | |
1638 | void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) { |
1639 | assert(op && newOp && "expected non-null op" ); |
1640 | replaceOp(op, newValues: newOp->getResults()); |
1641 | } |
1642 | |
1643 | void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { |
1644 | assert(op->getNumResults() == newValues.size() && |
1645 | "incorrect # of replacement values" ); |
1646 | LLVM_DEBUG({ |
1647 | impl->logger.startLine() |
1648 | << "** Replace : '" << op->getName() << "'(" << op << ")\n" ; |
1649 | }); |
1650 | impl->notifyOpReplaced(op, newValues); |
1651 | } |
1652 | |
1653 | void ConversionPatternRewriter::eraseOp(Operation *op) { |
1654 | LLVM_DEBUG({ |
1655 | impl->logger.startLine() |
1656 | << "** Erase : '" << op->getName() << "'(" << op << ")\n" ; |
1657 | }); |
1658 | SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr); |
1659 | impl->notifyOpReplaced(op, newValues: nullRepls); |
1660 | } |
1661 | |
1662 | void ConversionPatternRewriter::eraseBlock(Block *block) { |
1663 | assert(!impl->wasOpReplaced(block->getParentOp()) && |
1664 | "attempting to erase a block within a replaced/erased op" ); |
1665 | |
1666 | // Mark all ops for erasure. |
1667 | for (Operation &op : *block) |
1668 | eraseOp(op: &op); |
1669 | |
1670 | // Unlink the block from its parent region. The block is kept in the rewrite |
1671 | // object and will be actually destroyed when rewrites are applied. This |
1672 | // allows us to keep the operations in the block live and undo the removal by |
1673 | // re-inserting the block. |
1674 | impl->notifyBlockIsBeingErased(block); |
1675 | block->getParent()->getBlocks().remove(IT: block); |
1676 | } |
1677 | |
1678 | Block *ConversionPatternRewriter::applySignatureConversion( |
1679 | Region *region, TypeConverter::SignatureConversion &conversion, |
1680 | const TypeConverter *converter) { |
1681 | assert(!impl->wasOpReplaced(region->getParentOp()) && |
1682 | "attempting to apply a signature conversion to a block within a " |
1683 | "replaced/erased op" ); |
1684 | return impl->applySignatureConversion(rewriter&: *this, region, conversion, converter); |
1685 | } |
1686 | |
1687 | FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes( |
1688 | Region *region, const TypeConverter &converter, |
1689 | TypeConverter::SignatureConversion *entryConversion) { |
1690 | assert(!impl->wasOpReplaced(region->getParentOp()) && |
1691 | "attempting to apply a signature conversion to a block within a " |
1692 | "replaced/erased op" ); |
1693 | return impl->convertRegionTypes(rewriter&: *this, region, converter, entryConversion); |
1694 | } |
1695 | |
1696 | LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes( |
1697 | Region *region, const TypeConverter &converter, |
1698 | ArrayRef<TypeConverter::SignatureConversion> blockConversions) { |
1699 | assert(!impl->wasOpReplaced(region->getParentOp()) && |
1700 | "attempting to apply a signature conversion to a block within a " |
1701 | "replaced/erased op" ); |
1702 | return impl->convertNonEntryRegionTypes(rewriter&: *this, region, converter, |
1703 | blockConversions); |
1704 | } |
1705 | |
1706 | void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, |
1707 | Value to) { |
1708 | LLVM_DEBUG({ |
1709 | Operation *parentOp = from.getOwner()->getParentOp(); |
1710 | impl->logger.startLine() << "** Replace Argument : '" << from |
1711 | << "'(in region of '" << parentOp->getName() |
1712 | << "'(" << from.getOwner()->getParentOp() << ")\n" ; |
1713 | }); |
1714 | impl->appendRewrite<ReplaceBlockArgRewrite>(args: from.getOwner(), args&: from); |
1715 | impl->mapping.map(oldVal: impl->mapping.lookupOrDefault(from), newVal: to); |
1716 | } |
1717 | |
1718 | Value ConversionPatternRewriter::getRemappedValue(Value key) { |
1719 | SmallVector<Value> remappedValues; |
1720 | if (failed(result: impl->remapValues(valueDiagTag: "value" , /*inputLoc=*/std::nullopt, rewriter&: *this, values: key, |
1721 | remapped&: remappedValues))) |
1722 | return nullptr; |
1723 | return remappedValues.front(); |
1724 | } |
1725 | |
1726 | LogicalResult |
1727 | ConversionPatternRewriter::getRemappedValues(ValueRange keys, |
1728 | SmallVectorImpl<Value> &results) { |
1729 | if (keys.empty()) |
1730 | return success(); |
1731 | return impl->remapValues(valueDiagTag: "value" , /*inputLoc=*/std::nullopt, rewriter&: *this, values: keys, |
1732 | remapped&: results); |
1733 | } |
1734 | |
1735 | void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, |
1736 | Block::iterator before, |
1737 | ValueRange argValues) { |
1738 | #ifndef NDEBUG |
1739 | assert(argValues.size() == source->getNumArguments() && |
1740 | "incorrect # of argument replacement values" ); |
1741 | assert(!impl->wasOpReplaced(source->getParentOp()) && |
1742 | "attempting to inline a block from a replaced/erased op" ); |
1743 | assert(!impl->wasOpReplaced(dest->getParentOp()) && |
1744 | "attempting to inline a block into a replaced/erased op" ); |
1745 | auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); }; |
1746 | // The source block will be deleted, so it should not have any users (i.e., |
1747 | // there should be no predecessors). |
1748 | assert(llvm::all_of(source->getUsers(), opIgnored) && |
1749 | "expected 'source' to have no predecessors" ); |
1750 | #endif // NDEBUG |
1751 | |
1752 | // If a listener is attached to the dialect conversion, ops cannot be moved |
1753 | // to the destination block in bulk ("fast path"). This is because at the time |
1754 | // the notifications are sent, it is unknown which ops were moved. Instead, |
1755 | // ops should be moved one-by-one ("slow path"), so that a separate |
1756 | // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is |
1757 | // a bit more efficient, so we try to do that when possible. |
1758 | bool fastPath = !impl->config.listener; |
1759 | |
1760 | if (fastPath) |
1761 | impl->notifyBlockBeingInlined(block: dest, srcBlock: source, before); |
1762 | |
1763 | // Replace all uses of block arguments. |
1764 | for (auto it : llvm::zip(t: source->getArguments(), u&: argValues)) |
1765 | replaceUsesOfBlockArgument(from: std::get<0>(t&: it), to: std::get<1>(t&: it)); |
1766 | |
1767 | if (fastPath) { |
1768 | // Move all ops at once. |
1769 | dest->getOperations().splice(where: before, L2&: source->getOperations()); |
1770 | } else { |
1771 | // Move op by op. |
1772 | while (!source->empty()) |
1773 | moveOpBefore(op: &source->front(), block: dest, iterator: before); |
1774 | } |
1775 | |
1776 | // Erase the source block. |
1777 | eraseBlock(block: source); |
1778 | } |
1779 | |
1780 | void ConversionPatternRewriter::startOpModification(Operation *op) { |
1781 | assert(!impl->wasOpReplaced(op) && |
1782 | "attempting to modify a replaced/erased op" ); |
1783 | #ifndef NDEBUG |
1784 | impl->pendingRootUpdates.insert(Ptr: op); |
1785 | #endif |
1786 | impl->appendRewrite<ModifyOperationRewrite>(args&: op); |
1787 | } |
1788 | |
1789 | void ConversionPatternRewriter::finalizeOpModification(Operation *op) { |
1790 | assert(!impl->wasOpReplaced(op) && |
1791 | "attempting to modify a replaced/erased op" ); |
1792 | PatternRewriter::finalizeOpModification(op); |
1793 | // There is nothing to do here, we only need to track the operation at the |
1794 | // start of the update. |
1795 | #ifndef NDEBUG |
1796 | assert(impl->pendingRootUpdates.erase(op) && |
1797 | "operation did not have a pending in-place update" ); |
1798 | #endif |
1799 | } |
1800 | |
1801 | void ConversionPatternRewriter::cancelOpModification(Operation *op) { |
1802 | #ifndef NDEBUG |
1803 | assert(impl->pendingRootUpdates.erase(op) && |
1804 | "operation did not have a pending in-place update" ); |
1805 | #endif |
1806 | // Erase the last update for this operation. |
1807 | auto it = llvm::find_if( |
1808 | Range: llvm::reverse(C&: impl->rewrites), P: [&](std::unique_ptr<IRRewrite> &rewrite) { |
1809 | auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(Val: rewrite.get()); |
1810 | return modifyRewrite && modifyRewrite->getOperation() == op; |
1811 | }); |
1812 | assert(it != impl->rewrites.rend() && "no root update started on op" ); |
1813 | (*it)->rollback(); |
1814 | int updateIdx = std::prev(x: impl->rewrites.rend()) - it; |
1815 | impl->rewrites.erase(CI: impl->rewrites.begin() + updateIdx); |
1816 | } |
1817 | |
1818 | detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { |
1819 | return *impl; |
1820 | } |
1821 | |
1822 | //===----------------------------------------------------------------------===// |
1823 | // ConversionPattern |
1824 | //===----------------------------------------------------------------------===// |
1825 | |
1826 | LogicalResult |
1827 | ConversionPattern::matchAndRewrite(Operation *op, |
1828 | PatternRewriter &rewriter) const { |
1829 | auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter); |
1830 | auto &rewriterImpl = dialectRewriter.getImpl(); |
1831 | |
1832 | // Track the current conversion pattern type converter in the rewriter. |
1833 | llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter, |
1834 | getTypeConverter()); |
1835 | |
1836 | // Remap the operands of the operation. |
1837 | SmallVector<Value, 4> operands; |
1838 | if (failed(result: rewriterImpl.remapValues(valueDiagTag: "operand" , inputLoc: op->getLoc(), rewriter, |
1839 | values: op->getOperands(), remapped&: operands))) { |
1840 | return failure(); |
1841 | } |
1842 | return matchAndRewrite(op, operands, rewriter&: dialectRewriter); |
1843 | } |
1844 | |
1845 | //===----------------------------------------------------------------------===// |
1846 | // OperationLegalizer |
1847 | //===----------------------------------------------------------------------===// |
1848 | |
1849 | namespace { |
1850 | /// A set of rewrite patterns that can be used to legalize a given operation. |
1851 | using LegalizationPatterns = SmallVector<const Pattern *, 1>; |
1852 | |
1853 | /// This class defines a recursive operation legalizer. |
1854 | class OperationLegalizer { |
1855 | public: |
1856 | using LegalizationAction = ConversionTarget::LegalizationAction; |
1857 | |
1858 | OperationLegalizer(const ConversionTarget &targetInfo, |
1859 | const FrozenRewritePatternSet &patterns, |
1860 | const ConversionConfig &config); |
1861 | |
1862 | /// Returns true if the given operation is known to be illegal on the target. |
1863 | bool isIllegal(Operation *op) const; |
1864 | |
1865 | /// Attempt to legalize the given operation. Returns success if the operation |
1866 | /// was legalized, failure otherwise. |
1867 | LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter); |
1868 | |
1869 | /// Returns the conversion target in use by the legalizer. |
1870 | const ConversionTarget &getTarget() { return target; } |
1871 | |
1872 | private: |
1873 | /// Attempt to legalize the given operation by folding it. |
1874 | LogicalResult legalizeWithFold(Operation *op, |
1875 | ConversionPatternRewriter &rewriter); |
1876 | |
1877 | /// Attempt to legalize the given operation by applying a pattern. Returns |
1878 | /// success if the operation was legalized, failure otherwise. |
1879 | LogicalResult legalizeWithPattern(Operation *op, |
1880 | ConversionPatternRewriter &rewriter); |
1881 | |
1882 | /// Return true if the given pattern may be applied to the given operation, |
1883 | /// false otherwise. |
1884 | bool canApplyPattern(Operation *op, const Pattern &pattern, |
1885 | ConversionPatternRewriter &rewriter); |
1886 | |
1887 | /// Legalize the resultant IR after successfully applying the given pattern. |
1888 | LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern, |
1889 | ConversionPatternRewriter &rewriter, |
1890 | RewriterState &curState); |
1891 | |
1892 | /// Legalizes the actions registered during the execution of a pattern. |
1893 | LogicalResult |
1894 | legalizePatternBlockRewrites(Operation *op, |
1895 | ConversionPatternRewriter &rewriter, |
1896 | ConversionPatternRewriterImpl &impl, |
1897 | RewriterState &state, RewriterState &newState); |
1898 | LogicalResult legalizePatternCreatedOperations( |
1899 | ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, |
1900 | RewriterState &state, RewriterState &newState); |
1901 | LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter, |
1902 | ConversionPatternRewriterImpl &impl, |
1903 | RewriterState &state, |
1904 | RewriterState &newState); |
1905 | |
1906 | //===--------------------------------------------------------------------===// |
1907 | // Cost Model |
1908 | //===--------------------------------------------------------------------===// |
1909 | |
1910 | /// Build an optimistic legalization graph given the provided patterns. This |
1911 | /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with |
1912 | /// patterns for operations that are not directly legal, but may be |
1913 | /// transitively legal for the current target given the provided patterns. |
1914 | void buildLegalizationGraph( |
1915 | LegalizationPatterns &anyOpLegalizerPatterns, |
1916 | DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns); |
1917 | |
1918 | /// Compute the benefit of each node within the computed legalization graph. |
1919 | /// This orders the patterns within 'legalizerPatterns' based upon two |
1920 | /// criteria: |
1921 | /// 1) Prefer patterns that have the lowest legalization depth, i.e. |
1922 | /// represent the more direct mapping to the target. |
1923 | /// 2) When comparing patterns with the same legalization depth, prefer the |
1924 | /// pattern with the highest PatternBenefit. This allows for users to |
1925 | /// prefer specific legalizations over others. |
1926 | void computeLegalizationGraphBenefit( |
1927 | LegalizationPatterns &anyOpLegalizerPatterns, |
1928 | DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns); |
1929 | |
1930 | /// Compute the legalization depth when legalizing an operation of the given |
1931 | /// type. |
1932 | unsigned computeOpLegalizationDepth( |
1933 | OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth, |
1934 | DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns); |
1935 | |
1936 | /// Apply the conversion cost model to the given set of patterns, and return |
1937 | /// the smallest legalization depth of any of the patterns. See |
1938 | /// `computeLegalizationGraphBenefit` for the breakdown of the cost model. |
1939 | unsigned applyCostModelToPatterns( |
1940 | LegalizationPatterns &patterns, |
1941 | DenseMap<OperationName, unsigned> &minOpPatternDepth, |
1942 | DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns); |
1943 | |
1944 | /// The current set of patterns that have been applied. |
1945 | SmallPtrSet<const Pattern *, 8> appliedPatterns; |
1946 | |
1947 | /// The legalization information provided by the target. |
1948 | const ConversionTarget ⌖ |
1949 | |
1950 | /// The pattern applicator to use for conversions. |
1951 | PatternApplicator applicator; |
1952 | |
1953 | /// Dialect conversion configuration. |
1954 | const ConversionConfig &config; |
1955 | }; |
1956 | } // namespace |
1957 | |
1958 | OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo, |
1959 | const FrozenRewritePatternSet &patterns, |
1960 | const ConversionConfig &config) |
1961 | : target(targetInfo), applicator(patterns), config(config) { |
1962 | // The set of patterns that can be applied to illegal operations to transform |
1963 | // them into legal ones. |
1964 | DenseMap<OperationName, LegalizationPatterns> legalizerPatterns; |
1965 | LegalizationPatterns anyOpLegalizerPatterns; |
1966 | |
1967 | buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns); |
1968 | computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns); |
1969 | } |
1970 | |
1971 | bool OperationLegalizer::isIllegal(Operation *op) const { |
1972 | return target.isIllegal(op); |
1973 | } |
1974 | |
1975 | LogicalResult |
1976 | OperationLegalizer::legalize(Operation *op, |
1977 | ConversionPatternRewriter &rewriter) { |
1978 | #ifndef NDEBUG |
1979 | const char * = |
1980 | "//===-------------------------------------------===//\n" ; |
1981 | |
1982 | auto &logger = rewriter.getImpl().logger; |
1983 | #endif |
1984 | LLVM_DEBUG({ |
1985 | logger.getOStream() << "\n" ; |
1986 | logger.startLine() << logLineComment; |
1987 | logger.startLine() << "Legalizing operation : '" << op->getName() << "'(" |
1988 | << op << ") {\n" ; |
1989 | logger.indent(); |
1990 | |
1991 | // If the operation has no regions, just print it here. |
1992 | if (op->getNumRegions() == 0) { |
1993 | op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm()); |
1994 | logger.getOStream() << "\n\n" ; |
1995 | } |
1996 | }); |
1997 | |
1998 | // Check if this operation is legal on the target. |
1999 | if (auto legalityInfo = target.isLegal(op)) { |
2000 | LLVM_DEBUG({ |
2001 | logSuccess( |
2002 | logger, "operation marked legal by the target{0}" , |
2003 | legalityInfo->isRecursivelyLegal |
2004 | ? "; NOTE: operation is recursively legal; skipping internals" |
2005 | : "" ); |
2006 | logger.startLine() << logLineComment; |
2007 | }); |
2008 | |
2009 | // If this operation is recursively legal, mark its children as ignored so |
2010 | // that we don't consider them for legalization. |
2011 | if (legalityInfo->isRecursivelyLegal) { |
2012 | op->walk(callback: [&](Operation *nested) { |
2013 | if (op != nested) |
2014 | rewriter.getImpl().ignoredOps.insert(X: nested); |
2015 | }); |
2016 | } |
2017 | |
2018 | return success(); |
2019 | } |
2020 | |
2021 | // Check to see if the operation is ignored and doesn't need to be converted. |
2022 | if (rewriter.getImpl().isOpIgnored(op)) { |
2023 | LLVM_DEBUG({ |
2024 | logSuccess(logger, "operation marked 'ignored' during conversion" ); |
2025 | logger.startLine() << logLineComment; |
2026 | }); |
2027 | return success(); |
2028 | } |
2029 | |
2030 | // If the operation isn't legal, try to fold it in-place. |
2031 | // TODO: Should we always try to do this, even if the op is |
2032 | // already legal? |
2033 | if (succeeded(result: legalizeWithFold(op, rewriter))) { |
2034 | LLVM_DEBUG({ |
2035 | logSuccess(logger, "operation was folded" ); |
2036 | logger.startLine() << logLineComment; |
2037 | }); |
2038 | return success(); |
2039 | } |
2040 | |
2041 | // Otherwise, we need to apply a legalization pattern to this operation. |
2042 | if (succeeded(result: legalizeWithPattern(op, rewriter))) { |
2043 | LLVM_DEBUG({ |
2044 | logSuccess(logger, "" ); |
2045 | logger.startLine() << logLineComment; |
2046 | }); |
2047 | return success(); |
2048 | } |
2049 | |
2050 | LLVM_DEBUG({ |
2051 | logFailure(logger, "no matched legalization pattern" ); |
2052 | logger.startLine() << logLineComment; |
2053 | }); |
2054 | return failure(); |
2055 | } |
2056 | |
2057 | LogicalResult |
2058 | OperationLegalizer::legalizeWithFold(Operation *op, |
2059 | ConversionPatternRewriter &rewriter) { |
2060 | auto &rewriterImpl = rewriter.getImpl(); |
2061 | RewriterState curState = rewriterImpl.getCurrentState(); |
2062 | |
2063 | LLVM_DEBUG({ |
2064 | rewriterImpl.logger.startLine() << "* Fold {\n" ; |
2065 | rewriterImpl.logger.indent(); |
2066 | }); |
2067 | |
2068 | // Try to fold the operation. |
2069 | SmallVector<Value, 2> replacementValues; |
2070 | rewriter.setInsertionPoint(op); |
2071 | if (failed(result: rewriter.tryFold(op, results&: replacementValues))) { |
2072 | LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold" )); |
2073 | return failure(); |
2074 | } |
2075 | // An empty list of replacement values indicates that the fold was in-place. |
2076 | // As the operation changed, a new legalization needs to be attempted. |
2077 | if (replacementValues.empty()) |
2078 | return legalize(op, rewriter); |
2079 | |
2080 | // Insert a replacement for 'op' with the folded replacement values. |
2081 | rewriter.replaceOp(op, newValues: replacementValues); |
2082 | |
2083 | // Recursively legalize any new constant operations. |
2084 | for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size(); |
2085 | i != e; ++i) { |
2086 | auto *createOp = |
2087 | dyn_cast<CreateOperationRewrite>(Val: rewriterImpl.rewrites[i].get()); |
2088 | if (!createOp) |
2089 | continue; |
2090 | if (failed(result: legalize(op: createOp->getOperation(), rewriter))) { |
2091 | LLVM_DEBUG(logFailure(rewriterImpl.logger, |
2092 | "failed to legalize generated constant '{0}'" , |
2093 | createOp->getOperation()->getName())); |
2094 | rewriterImpl.resetState(state: curState); |
2095 | return failure(); |
2096 | } |
2097 | } |
2098 | |
2099 | LLVM_DEBUG(logSuccess(rewriterImpl.logger, "" )); |
2100 | return success(); |
2101 | } |
2102 | |
2103 | LogicalResult |
2104 | OperationLegalizer::legalizeWithPattern(Operation *op, |
2105 | ConversionPatternRewriter &rewriter) { |
2106 | auto &rewriterImpl = rewriter.getImpl(); |
2107 | |
2108 | // Functor that returns if the given pattern may be applied. |
2109 | auto canApply = [&](const Pattern &pattern) { |
2110 | bool canApply = canApplyPattern(op, pattern, rewriter); |
2111 | if (canApply && config.listener) |
2112 | config.listener->notifyPatternBegin(pattern, op); |
2113 | return canApply; |
2114 | }; |
2115 | |
2116 | // Functor that cleans up the rewriter state after a pattern failed to match. |
2117 | RewriterState curState = rewriterImpl.getCurrentState(); |
2118 | auto onFailure = [&](const Pattern &pattern) { |
2119 | assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates" ); |
2120 | LLVM_DEBUG({ |
2121 | logFailure(rewriterImpl.logger, "pattern failed to match" ); |
2122 | if (rewriterImpl.config.notifyCallback) { |
2123 | Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark); |
2124 | diag << "Failed to apply pattern \"" << pattern.getDebugName() |
2125 | << "\" on op:\n" |
2126 | << *op; |
2127 | rewriterImpl.config.notifyCallback(diag); |
2128 | } |
2129 | }); |
2130 | if (config.listener) |
2131 | config.listener->notifyPatternEnd(pattern, status: failure()); |
2132 | rewriterImpl.resetState(state: curState); |
2133 | appliedPatterns.erase(Ptr: &pattern); |
2134 | }; |
2135 | |
2136 | // Functor that performs additional legalization when a pattern is |
2137 | // successfully applied. |
2138 | auto onSuccess = [&](const Pattern &pattern) { |
2139 | assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates" ); |
2140 | auto result = legalizePatternResult(op, pattern, rewriter, curState); |
2141 | appliedPatterns.erase(Ptr: &pattern); |
2142 | if (failed(result)) |
2143 | rewriterImpl.resetState(state: curState); |
2144 | if (config.listener) |
2145 | config.listener->notifyPatternEnd(pattern, status: result); |
2146 | return result; |
2147 | }; |
2148 | |
2149 | // Try to match and rewrite a pattern on this operation. |
2150 | return applicator.matchAndRewrite(op, rewriter, canApply, onFailure, |
2151 | onSuccess); |
2152 | } |
2153 | |
2154 | bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern, |
2155 | ConversionPatternRewriter &rewriter) { |
2156 | LLVM_DEBUG({ |
2157 | auto &os = rewriter.getImpl().logger; |
2158 | os.getOStream() << "\n" ; |
2159 | os.startLine() << "* Pattern : '" << op->getName() << " -> (" ; |
2160 | llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream()); |
2161 | os.getOStream() << ")' {\n" ; |
2162 | os.indent(); |
2163 | }); |
2164 | |
2165 | // Ensure that we don't cycle by not allowing the same pattern to be |
2166 | // applied twice in the same recursion stack if it is not known to be safe. |
2167 | if (!pattern.hasBoundedRewriteRecursion() && |
2168 | !appliedPatterns.insert(Ptr: &pattern).second) { |
2169 | LLVM_DEBUG( |
2170 | logFailure(rewriter.getImpl().logger, "pattern was already applied" )); |
2171 | return false; |
2172 | } |
2173 | return true; |
2174 | } |
2175 | |
2176 | LogicalResult |
2177 | OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern, |
2178 | ConversionPatternRewriter &rewriter, |
2179 | RewriterState &curState) { |
2180 | auto &impl = rewriter.getImpl(); |
2181 | |
2182 | #ifndef NDEBUG |
2183 | assert(impl.pendingRootUpdates.empty() && "dangling root updates" ); |
2184 | // Check that the root was either replaced or updated in place. |
2185 | auto newRewrites = llvm::drop_begin(RangeOrContainer&: impl.rewrites, N: curState.numRewrites); |
2186 | auto replacedRoot = [&] { |
2187 | return hasRewrite<ReplaceOperationRewrite>(rewrites&: newRewrites, op); |
2188 | }; |
2189 | auto updatedRootInPlace = [&] { |
2190 | return hasRewrite<ModifyOperationRewrite>(rewrites&: newRewrites, op); |
2191 | }; |
2192 | assert((replacedRoot() || updatedRootInPlace()) && |
2193 | "expected pattern to replace the root operation" ); |
2194 | #endif // NDEBUG |
2195 | |
2196 | // Legalize each of the actions registered during application. |
2197 | RewriterState newState = impl.getCurrentState(); |
2198 | if (failed(result: legalizePatternBlockRewrites(op, rewriter, impl, state&: curState, |
2199 | newState)) || |
2200 | failed(result: legalizePatternRootUpdates(rewriter, impl, state&: curState, newState)) || |
2201 | failed(result: legalizePatternCreatedOperations(rewriter, impl, state&: curState, |
2202 | newState))) { |
2203 | return failure(); |
2204 | } |
2205 | |
2206 | LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully" )); |
2207 | return success(); |
2208 | } |
2209 | |
2210 | LogicalResult OperationLegalizer::legalizePatternBlockRewrites( |
2211 | Operation *op, ConversionPatternRewriter &rewriter, |
2212 | ConversionPatternRewriterImpl &impl, RewriterState &state, |
2213 | RewriterState &newState) { |
2214 | SmallPtrSet<Operation *, 16> operationsToIgnore; |
2215 | |
2216 | // If the pattern moved or created any blocks, make sure the types of block |
2217 | // arguments get legalized. |
2218 | for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) { |
2219 | BlockRewrite *rewrite = dyn_cast<BlockRewrite>(Val: impl.rewrites[i].get()); |
2220 | if (!rewrite) |
2221 | continue; |
2222 | Block *block = rewrite->getBlock(); |
2223 | if (isa<BlockTypeConversionRewrite, EraseBlockRewrite, |
2224 | ReplaceBlockArgRewrite>(Val: rewrite)) |
2225 | continue; |
2226 | // Only check blocks outside of the current operation. |
2227 | Operation *parentOp = block->getParentOp(); |
2228 | if (!parentOp || parentOp == op || block->getNumArguments() == 0) |
2229 | continue; |
2230 | |
2231 | // If the region of the block has a type converter, try to convert the block |
2232 | // directly. |
2233 | if (auto *converter = impl.regionToConverter.lookup(Val: block->getParent())) { |
2234 | if (failed(result: impl.convertBlockSignature(rewriter, block, converter))) { |
2235 | LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved " |
2236 | "block" )); |
2237 | return failure(); |
2238 | } |
2239 | continue; |
2240 | } |
2241 | |
2242 | // Otherwise, check that this operation isn't one generated by this pattern. |
2243 | // This is because we will attempt to legalize the parent operation, and |
2244 | // blocks in regions created by this pattern will already be legalized later |
2245 | // on. If we haven't built the set yet, build it now. |
2246 | if (operationsToIgnore.empty()) { |
2247 | for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e; |
2248 | ++i) { |
2249 | auto *createOp = |
2250 | dyn_cast<CreateOperationRewrite>(Val: impl.rewrites[i].get()); |
2251 | if (!createOp) |
2252 | continue; |
2253 | operationsToIgnore.insert(Ptr: createOp->getOperation()); |
2254 | } |
2255 | } |
2256 | |
2257 | // If this operation should be considered for re-legalization, try it. |
2258 | if (operationsToIgnore.insert(Ptr: parentOp).second && |
2259 | failed(result: legalize(op: parentOp, rewriter))) { |
2260 | LLVM_DEBUG(logFailure(impl.logger, |
2261 | "operation '{0}'({1}) became illegal after rewrite" , |
2262 | parentOp->getName(), parentOp)); |
2263 | return failure(); |
2264 | } |
2265 | } |
2266 | return success(); |
2267 | } |
2268 | |
2269 | LogicalResult OperationLegalizer::legalizePatternCreatedOperations( |
2270 | ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, |
2271 | RewriterState &state, RewriterState &newState) { |
2272 | for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) { |
2273 | auto *createOp = dyn_cast<CreateOperationRewrite>(Val: impl.rewrites[i].get()); |
2274 | if (!createOp) |
2275 | continue; |
2276 | Operation *op = createOp->getOperation(); |
2277 | if (failed(result: legalize(op, rewriter))) { |
2278 | LLVM_DEBUG(logFailure(impl.logger, |
2279 | "failed to legalize generated operation '{0}'({1})" , |
2280 | op->getName(), op)); |
2281 | return failure(); |
2282 | } |
2283 | } |
2284 | return success(); |
2285 | } |
2286 | |
2287 | LogicalResult OperationLegalizer::legalizePatternRootUpdates( |
2288 | ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, |
2289 | RewriterState &state, RewriterState &newState) { |
2290 | for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) { |
2291 | auto *rewrite = dyn_cast<ModifyOperationRewrite>(Val: impl.rewrites[i].get()); |
2292 | if (!rewrite) |
2293 | continue; |
2294 | Operation *op = rewrite->getOperation(); |
2295 | if (failed(result: legalize(op, rewriter))) { |
2296 | LLVM_DEBUG(logFailure( |
2297 | impl.logger, "failed to legalize operation updated in-place '{0}'" , |
2298 | op->getName())); |
2299 | return failure(); |
2300 | } |
2301 | } |
2302 | return success(); |
2303 | } |
2304 | |
2305 | //===----------------------------------------------------------------------===// |
2306 | // Cost Model |
2307 | |
2308 | void OperationLegalizer::buildLegalizationGraph( |
2309 | LegalizationPatterns &anyOpLegalizerPatterns, |
2310 | DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) { |
2311 | // A mapping between an operation and a set of operations that can be used to |
2312 | // generate it. |
2313 | DenseMap<OperationName, SmallPtrSet<OperationName, 2>> parentOps; |
2314 | // A mapping between an operation and any currently invalid patterns it has. |
2315 | DenseMap<OperationName, SmallPtrSet<const Pattern *, 2>> invalidPatterns; |
2316 | // A worklist of patterns to consider for legality. |
2317 | SetVector<const Pattern *> patternWorklist; |
2318 | |
2319 | // Build the mapping from operations to the parent ops that may generate them. |
2320 | applicator.walkAllPatterns(walk: [&](const Pattern &pattern) { |
2321 | std::optional<OperationName> root = pattern.getRootKind(); |
2322 | |
2323 | // If the pattern has no specific root, we can't analyze the relationship |
2324 | // between the root op and generated operations. Given that, add all such |
2325 | // patterns to the legalization set. |
2326 | if (!root) { |
2327 | anyOpLegalizerPatterns.push_back(Elt: &pattern); |
2328 | return; |
2329 | } |
2330 | |
2331 | // Skip operations that are always known to be legal. |
2332 | if (target.getOpAction(op: *root) == LegalizationAction::Legal) |
2333 | return; |
2334 | |
2335 | // Add this pattern to the invalid set for the root op and record this root |
2336 | // as a parent for any generated operations. |
2337 | invalidPatterns[*root].insert(Ptr: &pattern); |
2338 | for (auto op : pattern.getGeneratedOps()) |
2339 | parentOps[op].insert(Ptr: *root); |
2340 | |
2341 | // Add this pattern to the worklist. |
2342 | patternWorklist.insert(X: &pattern); |
2343 | }); |
2344 | |
2345 | // If there are any patterns that don't have a specific root kind, we can't |
2346 | // make direct assumptions about what operations will never be legalized. |
2347 | // Note: Technically we could, but it would require an analysis that may |
2348 | // recurse into itself. It would be better to perform this kind of filtering |
2349 | // at a higher level than here anyways. |
2350 | if (!anyOpLegalizerPatterns.empty()) { |
2351 | for (const Pattern *pattern : patternWorklist) |
2352 | legalizerPatterns[*pattern->getRootKind()].push_back(Elt: pattern); |
2353 | return; |
2354 | } |
2355 | |
2356 | while (!patternWorklist.empty()) { |
2357 | auto *pattern = patternWorklist.pop_back_val(); |
2358 | |
2359 | // Check to see if any of the generated operations are invalid. |
2360 | if (llvm::any_of(Range: pattern->getGeneratedOps(), P: [&](OperationName op) { |
2361 | std::optional<LegalizationAction> action = target.getOpAction(op); |
2362 | return !legalizerPatterns.count(Val: op) && |
2363 | (!action || action == LegalizationAction::Illegal); |
2364 | })) |
2365 | continue; |
2366 | |
2367 | // Otherwise, if all of the generated operation are valid, this op is now |
2368 | // legal so add all of the child patterns to the worklist. |
2369 | legalizerPatterns[*pattern->getRootKind()].push_back(Elt: pattern); |
2370 | invalidPatterns[*pattern->getRootKind()].erase(Ptr: pattern); |
2371 | |
2372 | // Add any invalid patterns of the parent operations to see if they have now |
2373 | // become legal. |
2374 | for (auto op : parentOps[*pattern->getRootKind()]) |
2375 | patternWorklist.set_union(invalidPatterns[op]); |
2376 | } |
2377 | } |
2378 | |
2379 | void OperationLegalizer::computeLegalizationGraphBenefit( |
2380 | LegalizationPatterns &anyOpLegalizerPatterns, |
2381 | DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) { |
2382 | // The smallest pattern depth, when legalizing an operation. |
2383 | DenseMap<OperationName, unsigned> minOpPatternDepth; |
2384 | |
2385 | // For each operation that is transitively legal, compute a cost for it. |
2386 | for (auto &opIt : legalizerPatterns) |
2387 | if (!minOpPatternDepth.count(Val: opIt.first)) |
2388 | computeOpLegalizationDepth(op: opIt.first, minOpPatternDepth, |
2389 | legalizerPatterns); |
2390 | |
2391 | // Apply the cost model to the patterns that can match any operation. Those |
2392 | // with a specific operation type are already resolved when computing the op |
2393 | // legalization depth. |
2394 | if (!anyOpLegalizerPatterns.empty()) |
2395 | applyCostModelToPatterns(patterns&: anyOpLegalizerPatterns, minOpPatternDepth, |
2396 | legalizerPatterns); |
2397 | |
2398 | // Apply a cost model to the pattern applicator. We order patterns first by |
2399 | // depth then benefit. `legalizerPatterns` contains per-op patterns by |
2400 | // decreasing benefit. |
2401 | applicator.applyCostModel(model: [&](const Pattern &pattern) { |
2402 | ArrayRef<const Pattern *> orderedPatternList; |
2403 | if (std::optional<OperationName> rootName = pattern.getRootKind()) |
2404 | orderedPatternList = legalizerPatterns[*rootName]; |
2405 | else |
2406 | orderedPatternList = anyOpLegalizerPatterns; |
2407 | |
2408 | // If the pattern is not found, then it was removed and cannot be matched. |
2409 | auto *it = llvm::find(Range&: orderedPatternList, Val: &pattern); |
2410 | if (it == orderedPatternList.end()) |
2411 | return PatternBenefit::impossibleToMatch(); |
2412 | |
2413 | // Patterns found earlier in the list have higher benefit. |
2414 | return PatternBenefit(std::distance(first: it, last: orderedPatternList.end())); |
2415 | }); |
2416 | } |
2417 | |
2418 | unsigned OperationLegalizer::computeOpLegalizationDepth( |
2419 | OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth, |
2420 | DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) { |
2421 | // Check for existing depth. |
2422 | auto depthIt = minOpPatternDepth.find(Val: op); |
2423 | if (depthIt != minOpPatternDepth.end()) |
2424 | return depthIt->second; |
2425 | |
2426 | // If a mapping for this operation does not exist, then this operation |
2427 | // is always legal. Return 0 as the depth for a directly legal operation. |
2428 | auto opPatternsIt = legalizerPatterns.find(Val: op); |
2429 | if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty()) |
2430 | return 0u; |
2431 | |
2432 | // Record this initial depth in case we encounter this op again when |
2433 | // recursively computing the depth. |
2434 | minOpPatternDepth.try_emplace(Key: op, Args: std::numeric_limits<unsigned>::max()); |
2435 | |
2436 | // Apply the cost model to the operation patterns, and update the minimum |
2437 | // depth. |
2438 | unsigned minDepth = applyCostModelToPatterns( |
2439 | patterns&: opPatternsIt->second, minOpPatternDepth, legalizerPatterns); |
2440 | minOpPatternDepth[op] = minDepth; |
2441 | return minDepth; |
2442 | } |
2443 | |
2444 | unsigned OperationLegalizer::applyCostModelToPatterns( |
2445 | LegalizationPatterns &patterns, |
2446 | DenseMap<OperationName, unsigned> &minOpPatternDepth, |
2447 | DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) { |
2448 | unsigned minDepth = std::numeric_limits<unsigned>::max(); |
2449 | |
2450 | // Compute the depth for each pattern within the set. |
2451 | SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth; |
2452 | patternsByDepth.reserve(N: patterns.size()); |
2453 | for (const Pattern *pattern : patterns) { |
2454 | unsigned depth = 1; |
2455 | for (auto generatedOp : pattern->getGeneratedOps()) { |
2456 | unsigned generatedOpDepth = computeOpLegalizationDepth( |
2457 | op: generatedOp, minOpPatternDepth, legalizerPatterns); |
2458 | depth = std::max(a: depth, b: generatedOpDepth + 1); |
2459 | } |
2460 | patternsByDepth.emplace_back(Args&: pattern, Args&: depth); |
2461 | |
2462 | // Update the minimum depth of the pattern list. |
2463 | minDepth = std::min(a: minDepth, b: depth); |
2464 | } |
2465 | |
2466 | // If the operation only has one legalization pattern, there is no need to |
2467 | // sort them. |
2468 | if (patternsByDepth.size() == 1) |
2469 | return minDepth; |
2470 | |
2471 | // Sort the patterns by those likely to be the most beneficial. |
2472 | std::stable_sort(first: patternsByDepth.begin(), last: patternsByDepth.end(), |
2473 | comp: [](const std::pair<const Pattern *, unsigned> &lhs, |
2474 | const std::pair<const Pattern *, unsigned> &rhs) { |
2475 | // First sort by the smaller pattern legalization |
2476 | // depth. |
2477 | if (lhs.second != rhs.second) |
2478 | return lhs.second < rhs.second; |
2479 | |
2480 | // Then sort by the larger pattern benefit. |
2481 | auto lhsBenefit = lhs.first->getBenefit(); |
2482 | auto rhsBenefit = rhs.first->getBenefit(); |
2483 | return lhsBenefit > rhsBenefit; |
2484 | }); |
2485 | |
2486 | // Update the legalization pattern to use the new sorted list. |
2487 | patterns.clear(); |
2488 | for (auto &patternIt : patternsByDepth) |
2489 | patterns.push_back(Elt: patternIt.first); |
2490 | return minDepth; |
2491 | } |
2492 | |
2493 | //===----------------------------------------------------------------------===// |
2494 | // OperationConverter |
2495 | //===----------------------------------------------------------------------===// |
namespace {
/// Selects how OperationConverter reacts to operations that fail to legalize,
/// and whether the recorded rewrites are kept once the conversion succeeds.
enum OpConversionMode {
  /// In this mode, the conversion will ignore failed conversions to allow
  /// illegal operations to co-exist in the IR. Operations that were explicitly
  /// marked illegal still make the conversion fail; other operations that fail
  /// to legalize are recorded in `ConversionConfig::unlegalizedOps` if set.
  Partial,

  /// In this mode, all operations must be legal for the given target for the
  /// conversion to succeed; any legalization failure is reported as an error.
  Full,

  /// In this mode, operations are analyzed for legality. No actual rewrites are
  /// applied to the operations on success: the recorded rewrites are undone,
  /// and successfully legalized operations are recorded in
  /// `ConversionConfig::legalizableOps` if set.
  Analysis,
};
} // namespace
2511 | |
namespace mlir {
/// This class converts operations to a given conversion target via a set of
/// rewrite patterns. The conversion behaves differently depending on the
/// conversion mode (see OpConversionMode).
struct OperationConverter {
  /// Creates a converter for `target` using `patterns`. The configuration is
  /// copied into `config`, and `opLegalizer` is constructed from that copy
  /// (`this->config`). This relies on `config` being declared before
  /// `opLegalizer`, because members are initialized in declaration order; keep
  /// that order when editing the member list.
  explicit OperationConverter(const ConversionTarget &target,
                              const FrozenRewritePatternSet &patterns,
                              const ConversionConfig &config,
                              OpConversionMode mode)
      : config(config), opLegalizer(target, patterns, this->config),
        mode(mode) {}

  /// Converts the given operations (and their nested operations) to the
  /// conversion target. On failure, all recorded rewrites are undone. On
  /// success they are applied, except in analysis mode, where they are undone
  /// as well.
  LogicalResult convertOperations(ArrayRef<Operation *> ops);

private:
  /// Converts an operation with the given rewriter. How a legalization failure
  /// is handled depends on the conversion mode.
  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);

  /// This method is called after the conversion process to legalize any
  /// remaining artifacts and complete the conversion.
  LogicalResult finalize(ConversionPatternRewriter &rewriter);

  /// Legalize the types of converted block arguments.
  LogicalResult
  legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
                                 ConversionPatternRewriterImpl &rewriterImpl);

  /// Legalize any unresolved type materializations. `inverseMapping` may be
  /// left unset by this call; `finalize` computes it on demand afterwards if
  /// it is needed.
  LogicalResult legalizeUnresolvedMaterializations(
      ConversionPatternRewriter &rewriter,
      ConversionPatternRewriterImpl &rewriterImpl,
      std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping);

  /// Legalize an operation result that was marked as "erased" (i.e. replaced
  /// with a null value).
  LogicalResult
  legalizeErasedResult(Operation *op, OpResult result,
                       ConversionPatternRewriterImpl &rewriterImpl);

  /// Legalize an operation result that was replaced with a value of a different
  /// type.
  LogicalResult legalizeChangedResultType(
      Operation *op, OpResult result, Value newValue,
      const TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
      ConversionPatternRewriterImpl &rewriterImpl,
      const DenseMap<Value, SmallVector<Value>> &inverseMapping);

  /// Dialect conversion configuration. Must stay declared before
  /// `opLegalizer`, which is constructed from it.
  ConversionConfig config;

  /// The legalizer to use when converting operations.
  OperationLegalizer opLegalizer;

  /// The conversion mode to use when legalizing operations.
  OpConversionMode mode;
};
} // namespace mlir
2569 | |
2570 | LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter, |
2571 | Operation *op) { |
2572 | // Legalize the given operation. |
2573 | if (failed(result: opLegalizer.legalize(op, rewriter))) { |
2574 | // Handle the case of a failed conversion for each of the different modes. |
2575 | // Full conversions expect all operations to be converted. |
2576 | if (mode == OpConversionMode::Full) |
2577 | return op->emitError() |
2578 | << "failed to legalize operation '" << op->getName() << "'" ; |
2579 | // Partial conversions allow conversions to fail iff the operation was not |
2580 | // explicitly marked as illegal. If the user provided a `unlegalizedOps` |
2581 | // set, non-legalizable ops are added to that set. |
2582 | if (mode == OpConversionMode::Partial) { |
2583 | if (opLegalizer.isIllegal(op)) |
2584 | return op->emitError() |
2585 | << "failed to legalize operation '" << op->getName() |
2586 | << "' that was explicitly marked illegal" ; |
2587 | if (config.unlegalizedOps) |
2588 | config.unlegalizedOps->insert(V: op); |
2589 | } |
2590 | } else if (mode == OpConversionMode::Analysis) { |
2591 | // Analysis conversions don't fail if any operations fail to legalize, |
2592 | // they are only interested in the operations that were successfully |
2593 | // legalized. |
2594 | if (config.legalizableOps) |
2595 | config.legalizableOps->insert(V: op); |
2596 | } |
2597 | return success(); |
2598 | } |
2599 | |
2600 | LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) { |
2601 | if (ops.empty()) |
2602 | return success(); |
2603 | const ConversionTarget &target = opLegalizer.getTarget(); |
2604 | |
2605 | // Compute the set of operations and blocks to convert. |
2606 | SmallVector<Operation *> toConvert; |
2607 | for (auto *op : ops) { |
2608 | op->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>( |
2609 | callback: [&](Operation *op) { |
2610 | toConvert.push_back(Elt: op); |
2611 | // Don't check this operation's children for conversion if the |
2612 | // operation is recursively legal. |
2613 | auto legalityInfo = target.isLegal(op); |
2614 | if (legalityInfo && legalityInfo->isRecursivelyLegal) |
2615 | return WalkResult::skip(); |
2616 | return WalkResult::advance(); |
2617 | }); |
2618 | } |
2619 | |
2620 | // Convert each operation and discard rewrites on failure. |
2621 | ConversionPatternRewriter rewriter(ops.front()->getContext(), config); |
2622 | ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); |
2623 | |
2624 | for (auto *op : toConvert) |
2625 | if (failed(result: convert(rewriter, op))) |
2626 | return rewriterImpl.undoRewrites(), failure(); |
2627 | |
2628 | // Now that all of the operations have been converted, finalize the conversion |
2629 | // process to ensure any lingering conversion artifacts are cleaned up and |
2630 | // legalized. |
2631 | if (failed(result: finalize(rewriter))) |
2632 | return rewriterImpl.undoRewrites(), failure(); |
2633 | |
2634 | // After a successful conversion, apply rewrites if this is not an analysis |
2635 | // conversion. |
2636 | if (mode == OpConversionMode::Analysis) { |
2637 | rewriterImpl.undoRewrites(); |
2638 | } else { |
2639 | rewriterImpl.applyRewrites(); |
2640 | } |
2641 | return success(); |
2642 | } |
2643 | |
2644 | LogicalResult |
2645 | OperationConverter::finalize(ConversionPatternRewriter &rewriter) { |
2646 | std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping; |
2647 | ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); |
2648 | if (failed(result: legalizeUnresolvedMaterializations(rewriter, rewriterImpl, |
2649 | inverseMapping)) || |
2650 | failed(result: legalizeConvertedArgumentTypes(rewriter, rewriterImpl))) |
2651 | return failure(); |
2652 | |
2653 | // Process requested operation replacements. |
2654 | for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) { |
2655 | auto *opReplacement = |
2656 | dyn_cast<ReplaceOperationRewrite>(Val: rewriterImpl.rewrites[i].get()); |
2657 | if (!opReplacement || !opReplacement->hasChangedResults()) |
2658 | continue; |
2659 | Operation *op = opReplacement->getOperation(); |
2660 | for (OpResult result : op->getResults()) { |
2661 | Value newValue = rewriterImpl.mapping.lookupOrNull(from: result); |
2662 | |
2663 | // If the operation result was replaced with null, all of the uses of this |
2664 | // value should be replaced. |
2665 | if (!newValue) { |
2666 | if (failed(result: legalizeErasedResult(op, result, rewriterImpl))) |
2667 | return failure(); |
2668 | continue; |
2669 | } |
2670 | |
2671 | // Otherwise, check to see if the type of the result changed. |
2672 | if (result.getType() == newValue.getType()) |
2673 | continue; |
2674 | |
2675 | // Compute the inverse mapping only if it is really needed. |
2676 | if (!inverseMapping) |
2677 | inverseMapping = rewriterImpl.mapping.getInverse(); |
2678 | |
2679 | // Legalize this result. |
2680 | rewriter.setInsertionPoint(op); |
2681 | if (failed(result: legalizeChangedResultType( |
2682 | op, result, newValue, replConverter: opReplacement->getConverter(), rewriter, |
2683 | rewriterImpl, inverseMapping: *inverseMapping))) |
2684 | return failure(); |
2685 | } |
2686 | } |
2687 | return success(); |
2688 | } |
2689 | |
2690 | LogicalResult OperationConverter::legalizeConvertedArgumentTypes( |
2691 | ConversionPatternRewriter &rewriter, |
2692 | ConversionPatternRewriterImpl &rewriterImpl) { |
2693 | // Functor used to check if all users of a value will be dead after |
2694 | // conversion. |
2695 | auto findLiveUser = [&](Value val) { |
2696 | auto liveUserIt = llvm::find_if_not(Range: val.getUsers(), P: [&](Operation *user) { |
2697 | return rewriterImpl.isOpIgnored(op: user); |
2698 | }); |
2699 | return liveUserIt == val.user_end() ? nullptr : *liveUserIt; |
2700 | }; |
2701 | // Note: `rewrites` may be reallocated as the loop is running. |
2702 | for (int64_t i = 0; i < static_cast<int64_t>(rewriterImpl.rewrites.size()); |
2703 | ++i) { |
2704 | auto &rewrite = rewriterImpl.rewrites[i]; |
2705 | if (auto *blockTypeConversionRewrite = |
2706 | dyn_cast<BlockTypeConversionRewrite>(Val: rewrite.get())) |
2707 | if (failed(result: blockTypeConversionRewrite->materializeLiveConversions( |
2708 | findLiveUser))) |
2709 | return failure(); |
2710 | } |
2711 | return success(); |
2712 | } |
2713 | |
2714 | /// Replace the results of a materialization operation with the given values. |
2715 | static void |
2716 | replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl, |
2717 | ResultRange matResults, ValueRange values, |
2718 | DenseMap<Value, SmallVector<Value>> &inverseMapping) { |
2719 | matResults.replaceAllUsesWith(values); |
2720 | |
2721 | // For each of the materialization results, update the inverse mappings to |
2722 | // point to the replacement values. |
2723 | for (auto [matResult, newValue] : llvm::zip(t&: matResults, u&: values)) { |
2724 | auto inverseMapIt = inverseMapping.find(Val: matResult); |
2725 | if (inverseMapIt == inverseMapping.end()) |
2726 | continue; |
2727 | |
2728 | // Update the reverse mapping, or remove the mapping if we couldn't update |
2729 | // it. Not being able to update signals that the mapping would have become |
2730 | // circular (i.e. %foo -> newValue -> %foo), which may occur as values are |
2731 | // propagated through temporary materializations. We simply drop the |
2732 | // mapping, and let the post-conversion replacement logic handle updating |
2733 | // uses. |
2734 | for (Value inverseMapVal : inverseMapIt->second) |
2735 | if (!rewriterImpl.mapping.tryMap(oldVal: inverseMapVal, newVal: newValue)) |
2736 | rewriterImpl.mapping.erase(value: inverseMapVal); |
2737 | } |
2738 | } |
2739 | |
2740 | /// Compute all of the unresolved materializations that will persist beyond the |
2741 | /// conversion process, and require inserting a proper user materialization for. |
2742 | static void computeNecessaryMaterializations( |
2743 | DenseMap<Operation *, UnresolvedMaterializationRewrite *> |
2744 | &materializationOps, |
2745 | ConversionPatternRewriter &rewriter, |
2746 | ConversionPatternRewriterImpl &rewriterImpl, |
2747 | DenseMap<Value, SmallVector<Value>> &inverseMapping, |
2748 | SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) { |
2749 | auto isLive = [&](Value value) { |
2750 | auto findFn = [&](Operation *user) { |
2751 | auto matIt = materializationOps.find(Val: user); |
2752 | if (matIt != materializationOps.end()) |
2753 | return !necessaryMaterializations.count(key: matIt->second); |
2754 | return rewriterImpl.isOpIgnored(op: user); |
2755 | }; |
2756 | // This value may be replacing another value that has a live user. |
2757 | for (Value inv : inverseMapping.lookup(Val: value)) |
2758 | if (llvm::find_if_not(Range: inv.getUsers(), P: findFn) != inv.user_end()) |
2759 | return true; |
2760 | // Or have live users itself. |
2761 | return llvm::find_if_not(Range: value.getUsers(), P: findFn) != value.user_end(); |
2762 | }; |
2763 | |
2764 | llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue = |
2765 | [&](Value invalidRoot, Value value, Type type) { |
2766 | // Check to see if the input operation was remapped to a variant of the |
2767 | // output. |
2768 | Value remappedValue = rewriterImpl.mapping.lookupOrDefault(from: value, desiredType: type); |
2769 | if (remappedValue.getType() == type && remappedValue != invalidRoot) |
2770 | return remappedValue; |
2771 | |
2772 | // Check to see if the input is a materialization operation that |
2773 | // provides an inverse conversion. We just check blindly for |
2774 | // UnrealizedConversionCastOp here, but it has no effect on correctness. |
2775 | auto inputCastOp = value.getDefiningOp<UnrealizedConversionCastOp>(); |
2776 | if (inputCastOp && inputCastOp->getNumOperands() == 1) |
2777 | return lookupRemappedValue(invalidRoot, inputCastOp->getOperand(0), |
2778 | type); |
2779 | |
2780 | return Value(); |
2781 | }; |
2782 | |
2783 | SetVector<UnresolvedMaterializationRewrite *> worklist; |
2784 | for (auto &rewrite : rewriterImpl.rewrites) { |
2785 | auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(Val: rewrite.get()); |
2786 | if (!mat) |
2787 | continue; |
2788 | materializationOps.try_emplace(mat->getOperation(), mat); |
2789 | worklist.insert(X: mat); |
2790 | } |
2791 | while (!worklist.empty()) { |
2792 | UnresolvedMaterializationRewrite *mat = worklist.pop_back_val(); |
2793 | UnrealizedConversionCastOp op = mat->getOperation(); |
2794 | |
2795 | // We currently only handle target materializations here. |
2796 | assert(op->getNumResults() == 1 && "unexpected materialization type" ); |
2797 | OpResult opResult = op->getOpResult(0); |
2798 | Type outputType = opResult.getType(); |
2799 | Operation::operand_range inputOperands = op.getOperands(); |
2800 | |
2801 | // Try to forward propagate operands for user conversion casts that result |
2802 | // in the input types of the current cast. |
2803 | for (Operation *user : llvm::make_early_inc_range(opResult.getUsers())) { |
2804 | auto castOp = dyn_cast<UnrealizedConversionCastOp>(user); |
2805 | if (!castOp) |
2806 | continue; |
2807 | if (castOp->getResultTypes() == inputOperands.getTypes()) { |
2808 | replaceMaterialization(rewriterImpl, opResult, inputOperands, |
2809 | inverseMapping); |
2810 | necessaryMaterializations.remove(materializationOps.lookup(user)); |
2811 | } |
2812 | } |
2813 | |
2814 | // Try to avoid materializing a resolved materialization if possible. |
2815 | // Handle the case of a 1-1 materialization. |
2816 | if (inputOperands.size() == 1) { |
2817 | // Check to see if the input operation was remapped to a variant of the |
2818 | // output. |
2819 | Value remappedValue = |
2820 | lookupRemappedValue(opResult, inputOperands[0], outputType); |
2821 | if (remappedValue && remappedValue != opResult) { |
2822 | replaceMaterialization(rewriterImpl, matResults: opResult, values: remappedValue, |
2823 | inverseMapping); |
2824 | necessaryMaterializations.remove(X: mat); |
2825 | continue; |
2826 | } |
2827 | } else { |
2828 | // TODO: Avoid materializing other types of conversions here. |
2829 | } |
2830 | |
2831 | // Check to see if this is an argument materialization. |
2832 | if (llvm::any_of(op->getOperands(), llvm::IsaPred<BlockArgument>) || |
2833 | llvm::any_of(inverseMapping[op->getResult(0)], |
2834 | llvm::IsaPred<BlockArgument>)) { |
2835 | mat->setMaterializationKind(MaterializationKind::Argument); |
2836 | } |
2837 | |
2838 | // If the materialization does not have any live users, we don't need to |
2839 | // generate a user materialization for it. |
2840 | // FIXME: For argument materializations, we currently need to check if any |
2841 | // of the inverse mapped values are used because some patterns expect blind |
2842 | // value replacement even if the types differ in some cases. When those |
2843 | // patterns are fixed, we can drop the argument special case here. |
2844 | bool isMaterializationLive = isLive(opResult); |
2845 | if (mat->getMaterializationKind() == MaterializationKind::Argument) |
2846 | isMaterializationLive |= llvm::any_of(Range&: inverseMapping[opResult], P: isLive); |
2847 | if (!isMaterializationLive) |
2848 | continue; |
2849 | if (!necessaryMaterializations.insert(X: mat)) |
2850 | continue; |
2851 | |
2852 | // Reprocess input materializations to see if they have an updated status. |
2853 | for (Value input : inputOperands) { |
2854 | if (auto parentOp = input.getDefiningOp<UnrealizedConversionCastOp>()) { |
2855 | if (auto *mat = materializationOps.lookup(parentOp)) |
2856 | worklist.insert(mat); |
2857 | } |
2858 | } |
2859 | } |
2860 | } |
2861 | |
/// Legalize the given unresolved materialization. Returns success if the
/// materialization was legalized, failure otherwise.
2864 | static LogicalResult legalizeUnresolvedMaterialization( |
2865 | UnresolvedMaterializationRewrite &mat, |
2866 | DenseMap<Operation *, UnresolvedMaterializationRewrite *> |
2867 | &materializationOps, |
2868 | ConversionPatternRewriter &rewriter, |
2869 | ConversionPatternRewriterImpl &rewriterImpl, |
2870 | DenseMap<Value, SmallVector<Value>> &inverseMapping) { |
2871 | auto findLiveUser = [&](auto &&users) { |
2872 | auto liveUserIt = llvm::find_if_not( |
2873 | users, [&](Operation *user) { return rewriterImpl.isOpIgnored(op: user); }); |
2874 | return liveUserIt == users.end() ? nullptr : *liveUserIt; |
2875 | }; |
2876 | |
2877 | llvm::unique_function<Value(Value, Type)> lookupRemappedValue = |
2878 | [&](Value value, Type type) { |
2879 | // Check to see if the input operation was remapped to a variant of the |
2880 | // output. |
2881 | Value remappedValue = rewriterImpl.mapping.lookupOrDefault(from: value, desiredType: type); |
2882 | if (remappedValue.getType() == type) |
2883 | return remappedValue; |
2884 | return Value(); |
2885 | }; |
2886 | |
2887 | UnrealizedConversionCastOp op = mat.getOperation(); |
2888 | if (!rewriterImpl.ignoredOps.insert(op)) |
2889 | return success(); |
2890 | |
2891 | // We currently only handle target materializations here. |
2892 | OpResult opResult = op->getOpResult(0); |
2893 | Operation::operand_range inputOperands = op.getOperands(); |
2894 | Type outputType = opResult.getType(); |
2895 | |
2896 | // If any input to this materialization is another materialization, resolve |
2897 | // the input first. |
2898 | for (Value value : op->getOperands()) { |
2899 | auto valueCast = value.getDefiningOp<UnrealizedConversionCastOp>(); |
2900 | if (!valueCast) |
2901 | continue; |
2902 | |
2903 | auto matIt = materializationOps.find(valueCast); |
2904 | if (matIt != materializationOps.end()) |
2905 | if (failed(legalizeUnresolvedMaterialization( |
2906 | *matIt->second, materializationOps, rewriter, rewriterImpl, |
2907 | inverseMapping))) |
2908 | return failure(); |
2909 | } |
2910 | |
2911 | // Perform a last ditch attempt to avoid materializing a resolved |
2912 | // materialization if possible. |
2913 | // Handle the case of a 1-1 materialization. |
2914 | if (inputOperands.size() == 1) { |
2915 | // Check to see if the input operation was remapped to a variant of the |
2916 | // output. |
2917 | Value remappedValue = lookupRemappedValue(inputOperands[0], outputType); |
2918 | if (remappedValue && remappedValue != opResult) { |
2919 | replaceMaterialization(rewriterImpl, matResults: opResult, values: remappedValue, |
2920 | inverseMapping); |
2921 | return success(); |
2922 | } |
2923 | } else { |
2924 | // TODO: Avoid materializing other types of conversions here. |
2925 | } |
2926 | |
2927 | // Try to materialize the conversion. |
2928 | if (const TypeConverter *converter = mat.getConverter()) { |
2929 | // FIXME: Determine a suitable insertion location when there are multiple |
2930 | // inputs. |
2931 | if (inputOperands.size() == 1) |
2932 | rewriter.setInsertionPointAfterValue(inputOperands.front()); |
2933 | else |
2934 | rewriter.setInsertionPoint(op); |
2935 | |
2936 | Value newMaterialization; |
2937 | switch (mat.getMaterializationKind()) { |
2938 | case MaterializationKind::Argument: |
2939 | // Try to materialize an argument conversion. |
2940 | // FIXME: The current argument materialization hook expects the original |
2941 | // output type, even though it doesn't use that as the actual output type |
2942 | // of the generated IR. The output type is just used as an indicator of |
2943 | // the type of materialization to do. This behavior is really awkward in |
2944 | // that it diverges from the behavior of the other hooks, and can be |
2945 | // easily misunderstood. We should clean up the argument hooks to better |
2946 | // represent the desired invariants we actually care about. |
2947 | newMaterialization = converter->materializeArgumentConversion( |
2948 | builder&: rewriter, loc: op->getLoc(), resultType: mat.getOrigOutputType(), inputs: inputOperands); |
2949 | if (newMaterialization) |
2950 | break; |
2951 | |
2952 | // If an argument materialization failed, fallback to trying a target |
2953 | // materialization. |
2954 | [[fallthrough]]; |
2955 | case MaterializationKind::Target: |
2956 | newMaterialization = converter->materializeTargetConversion( |
2957 | builder&: rewriter, loc: op->getLoc(), resultType: outputType, inputs: inputOperands); |
2958 | break; |
2959 | } |
2960 | if (newMaterialization) { |
2961 | replaceMaterialization(rewriterImpl, matResults: opResult, values: newMaterialization, |
2962 | inverseMapping); |
2963 | return success(); |
2964 | } |
2965 | } |
2966 | |
2967 | InFlightDiagnostic diag = op->emitError() |
2968 | << "failed to legalize unresolved materialization " |
2969 | "from " |
2970 | << inputOperands.getTypes() << " to " << outputType |
2971 | << " that remained live after conversion" ; |
2972 | if (Operation *liveUser = findLiveUser(op->getUsers())) { |
2973 | diag.attachNote(noteLoc: liveUser->getLoc()) |
2974 | << "see existing live user here: " << *liveUser; |
2975 | } |
2976 | return failure(); |
2977 | } |
2978 | |
2979 | LogicalResult OperationConverter::legalizeUnresolvedMaterializations( |
2980 | ConversionPatternRewriter &rewriter, |
2981 | ConversionPatternRewriterImpl &rewriterImpl, |
2982 | std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping) { |
2983 | inverseMapping = rewriterImpl.mapping.getInverse(); |
2984 | |
2985 | // As an initial step, compute all of the inserted materializations that we |
2986 | // expect to persist beyond the conversion process. |
2987 | DenseMap<Operation *, UnresolvedMaterializationRewrite *> materializationOps; |
2988 | SetVector<UnresolvedMaterializationRewrite *> necessaryMaterializations; |
2989 | computeNecessaryMaterializations(materializationOps, rewriter, rewriterImpl, |
2990 | inverseMapping&: *inverseMapping, necessaryMaterializations); |
2991 | |
2992 | // Once computed, legalize any necessary materializations. |
2993 | for (auto *mat : necessaryMaterializations) { |
2994 | if (failed(result: legalizeUnresolvedMaterialization( |
2995 | mat&: *mat, materializationOps, rewriter, rewriterImpl, inverseMapping&: *inverseMapping))) |
2996 | return failure(); |
2997 | } |
2998 | return success(); |
2999 | } |
3000 | |
3001 | LogicalResult OperationConverter::legalizeErasedResult( |
3002 | Operation *op, OpResult result, |
3003 | ConversionPatternRewriterImpl &rewriterImpl) { |
3004 | // If the operation result was replaced with null, all of the uses of this |
3005 | // value should be replaced. |
3006 | auto liveUserIt = llvm::find_if_not(Range: result.getUsers(), P: [&](Operation *user) { |
3007 | return rewriterImpl.isOpIgnored(op: user); |
3008 | }); |
3009 | if (liveUserIt != result.user_end()) { |
3010 | InFlightDiagnostic diag = op->emitError(message: "failed to legalize operation '" ) |
3011 | << op->getName() << "' marked as erased" ; |
3012 | diag.attachNote(noteLoc: liveUserIt->getLoc()) |
3013 | << "found live user of result #" << result.getResultNumber() << ": " |
3014 | << *liveUserIt; |
3015 | return failure(); |
3016 | } |
3017 | return success(); |
3018 | } |
3019 | |
3020 | /// Finds a user of the given value, or of any other value that the given value |
3021 | /// replaced, that was not replaced in the conversion process. |
3022 | static Operation *findLiveUserOfReplaced( |
3023 | Value initialValue, ConversionPatternRewriterImpl &rewriterImpl, |
3024 | const DenseMap<Value, SmallVector<Value>> &inverseMapping) { |
3025 | SmallVector<Value> worklist(1, initialValue); |
3026 | while (!worklist.empty()) { |
3027 | Value value = worklist.pop_back_val(); |
3028 | |
3029 | // Walk the users of this value to see if there are any live users that |
3030 | // weren't replaced during conversion. |
3031 | auto liveUserIt = llvm::find_if_not(Range: value.getUsers(), P: [&](Operation *user) { |
3032 | return rewriterImpl.isOpIgnored(op: user); |
3033 | }); |
3034 | if (liveUserIt != value.user_end()) |
3035 | return *liveUserIt; |
3036 | auto mapIt = inverseMapping.find(Val: value); |
3037 | if (mapIt != inverseMapping.end()) |
3038 | worklist.append(RHS: mapIt->second); |
3039 | } |
3040 | return nullptr; |
3041 | } |
3042 | |
3043 | LogicalResult OperationConverter::legalizeChangedResultType( |
3044 | Operation *op, OpResult result, Value newValue, |
3045 | const TypeConverter *replConverter, ConversionPatternRewriter &rewriter, |
3046 | ConversionPatternRewriterImpl &rewriterImpl, |
3047 | const DenseMap<Value, SmallVector<Value>> &inverseMapping) { |
3048 | Operation *liveUser = |
3049 | findLiveUserOfReplaced(initialValue: result, rewriterImpl, inverseMapping); |
3050 | if (!liveUser) |
3051 | return success(); |
3052 | |
3053 | // Functor used to emit a conversion error for a failed materialization. |
3054 | auto emitConversionError = [&] { |
3055 | InFlightDiagnostic diag = op->emitError() |
3056 | << "failed to materialize conversion for result #" |
3057 | << result.getResultNumber() << " of operation '" |
3058 | << op->getName() |
3059 | << "' that remained live after conversion" ; |
3060 | diag.attachNote(noteLoc: liveUser->getLoc()) |
3061 | << "see existing live user here: " << *liveUser; |
3062 | return failure(); |
3063 | }; |
3064 | |
3065 | // If the replacement has a type converter, attempt to materialize a |
3066 | // conversion back to the original type. |
3067 | if (!replConverter) |
3068 | return emitConversionError(); |
3069 | |
3070 | // Materialize a conversion for this live result value. |
3071 | Type resultType = result.getType(); |
3072 | Value convertedValue = replConverter->materializeSourceConversion( |
3073 | builder&: rewriter, loc: op->getLoc(), resultType, inputs: newValue); |
3074 | if (!convertedValue) |
3075 | return emitConversionError(); |
3076 | |
3077 | rewriterImpl.mapping.map(oldVal: result, newVal: convertedValue); |
3078 | return success(); |
3079 | } |
3080 | |
3081 | //===----------------------------------------------------------------------===// |
3082 | // Type Conversion |
3083 | //===----------------------------------------------------------------------===// |
3084 | |
3085 | void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo, |
3086 | ArrayRef<Type> types) { |
3087 | assert(!types.empty() && "expected valid types" ); |
3088 | remapInput(origInputNo, /*newInputNo=*/argTypes.size(), newInputCount: types.size()); |
3089 | addInputs(types); |
3090 | } |
3091 | |
3092 | void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) { |
3093 | assert(!types.empty() && |
3094 | "1->0 type remappings don't need to be added explicitly" ); |
3095 | argTypes.append(in_start: types.begin(), in_end: types.end()); |
3096 | } |
3097 | |
3098 | void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo, |
3099 | unsigned newInputNo, |
3100 | unsigned newInputCount) { |
3101 | assert(!remappedInputs[origInputNo] && "input has already been remapped" ); |
3102 | assert(newInputCount != 0 && "expected valid input count" ); |
3103 | remappedInputs[origInputNo] = |
3104 | InputMapping{.inputNo: newInputNo, .size: newInputCount, /*replacementValue=*/nullptr}; |
3105 | } |
3106 | |
3107 | void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo, |
3108 | Value replacementValue) { |
3109 | assert(!remappedInputs[origInputNo] && "input has already been remapped" ); |
3110 | remappedInputs[origInputNo] = |
3111 | InputMapping{.inputNo: origInputNo, /*size=*/0, .replacementValue: replacementValue}; |
3112 | } |
3113 | |
3114 | LogicalResult TypeConverter::convertType(Type t, |
3115 | SmallVectorImpl<Type> &results) const { |
3116 | { |
3117 | std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex, |
3118 | std::defer_lock); |
3119 | if (t.getContext()->isMultithreadingEnabled()) |
3120 | cacheReadLock.lock(); |
3121 | auto existingIt = cachedDirectConversions.find(Val: t); |
3122 | if (existingIt != cachedDirectConversions.end()) { |
3123 | if (existingIt->second) |
3124 | results.push_back(Elt: existingIt->second); |
3125 | return success(isSuccess: existingIt->second != nullptr); |
3126 | } |
3127 | auto multiIt = cachedMultiConversions.find(Val: t); |
3128 | if (multiIt != cachedMultiConversions.end()) { |
3129 | results.append(in_start: multiIt->second.begin(), in_end: multiIt->second.end()); |
3130 | return success(); |
3131 | } |
3132 | } |
3133 | // Walk the added converters in reverse order to apply the most recently |
3134 | // registered first. |
3135 | size_t currentCount = results.size(); |
3136 | |
3137 | std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex, |
3138 | std::defer_lock); |
3139 | |
3140 | for (const ConversionCallbackFn &converter : llvm::reverse(C: conversions)) { |
3141 | if (std::optional<LogicalResult> result = converter(t, results)) { |
3142 | if (t.getContext()->isMultithreadingEnabled()) |
3143 | cacheWriteLock.lock(); |
3144 | if (!succeeded(result: *result)) { |
3145 | cachedDirectConversions.try_emplace(Key: t, Args: nullptr); |
3146 | return failure(); |
3147 | } |
3148 | auto newTypes = ArrayRef<Type>(results).drop_front(N: currentCount); |
3149 | if (newTypes.size() == 1) |
3150 | cachedDirectConversions.try_emplace(Key: t, Args: newTypes.front()); |
3151 | else |
3152 | cachedMultiConversions.try_emplace(Key: t, Args: llvm::to_vector<2>(Range&: newTypes)); |
3153 | return success(); |
3154 | } |
3155 | } |
3156 | return failure(); |
3157 | } |
3158 | |
3159 | Type TypeConverter::convertType(Type t) const { |
3160 | // Use the multi-type result version to convert the type. |
3161 | SmallVector<Type, 1> results; |
3162 | if (failed(result: convertType(t, results))) |
3163 | return nullptr; |
3164 | |
3165 | // Check to ensure that only one type was produced. |
3166 | return results.size() == 1 ? results.front() : nullptr; |
3167 | } |
3168 | |
3169 | LogicalResult |
3170 | TypeConverter::convertTypes(TypeRange types, |
3171 | SmallVectorImpl<Type> &results) const { |
3172 | for (Type type : types) |
3173 | if (failed(result: convertType(t: type, results))) |
3174 | return failure(); |
3175 | return success(); |
3176 | } |
3177 | |
3178 | bool TypeConverter::isLegal(Type type) const { |
3179 | return convertType(t: type) == type; |
3180 | } |
3181 | bool TypeConverter::isLegal(Operation *op) const { |
3182 | return isLegal(range: op->getOperandTypes()) && isLegal(range: op->getResultTypes()); |
3183 | } |
3184 | |
3185 | bool TypeConverter::isLegal(Region *region) const { |
3186 | return llvm::all_of(Range&: *region, P: [this](Block &block) { |
3187 | return isLegal(range: block.getArgumentTypes()); |
3188 | }); |
3189 | } |
3190 | |
3191 | bool TypeConverter::isSignatureLegal(FunctionType ty) const { |
3192 | return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults())); |
3193 | } |
3194 | |
3195 | LogicalResult |
3196 | TypeConverter::convertSignatureArg(unsigned inputNo, Type type, |
3197 | SignatureConversion &result) const { |
3198 | // Try to convert the given input type. |
3199 | SmallVector<Type, 1> convertedTypes; |
3200 | if (failed(result: convertType(t: type, results&: convertedTypes))) |
3201 | return failure(); |
3202 | |
3203 | // If this argument is being dropped, there is nothing left to do. |
3204 | if (convertedTypes.empty()) |
3205 | return success(); |
3206 | |
3207 | // Otherwise, add the new inputs. |
3208 | result.addInputs(origInputNo: inputNo, types: convertedTypes); |
3209 | return success(); |
3210 | } |
3211 | LogicalResult |
3212 | TypeConverter::convertSignatureArgs(TypeRange types, |
3213 | SignatureConversion &result, |
3214 | unsigned origInputOffset) const { |
3215 | for (unsigned i = 0, e = types.size(); i != e; ++i) |
3216 | if (failed(result: convertSignatureArg(inputNo: origInputOffset + i, type: types[i], result))) |
3217 | return failure(); |
3218 | return success(); |
3219 | } |
3220 | |
3221 | Value TypeConverter::materializeConversion( |
3222 | ArrayRef<MaterializationCallbackFn> materializations, OpBuilder &builder, |
3223 | Location loc, Type resultType, ValueRange inputs) const { |
3224 | for (const MaterializationCallbackFn &fn : llvm::reverse(C&: materializations)) |
3225 | if (std::optional<Value> result = fn(builder, resultType, inputs, loc)) |
3226 | return *result; |
3227 | return nullptr; |
3228 | } |
3229 | |
3230 | std::optional<TypeConverter::SignatureConversion> |
3231 | TypeConverter::convertBlockSignature(Block *block) const { |
3232 | SignatureConversion conversion(block->getNumArguments()); |
3233 | if (failed(result: convertSignatureArgs(types: block->getArgumentTypes(), result&: conversion))) |
3234 | return std::nullopt; |
3235 | return conversion; |
3236 | } |
3237 | |
3238 | //===----------------------------------------------------------------------===// |
3239 | // Type attribute conversion |
3240 | //===----------------------------------------------------------------------===// |
3241 | TypeConverter::AttributeConversionResult |
3242 | TypeConverter::AttributeConversionResult::result(Attribute attr) { |
3243 | return AttributeConversionResult(attr, resultTag); |
3244 | } |
3245 | |
3246 | TypeConverter::AttributeConversionResult |
3247 | TypeConverter::AttributeConversionResult::na() { |
3248 | return AttributeConversionResult(nullptr, naTag); |
3249 | } |
3250 | |
3251 | TypeConverter::AttributeConversionResult |
3252 | TypeConverter::AttributeConversionResult::abort() { |
3253 | return AttributeConversionResult(nullptr, abortTag); |
3254 | } |
3255 | |
3256 | bool TypeConverter::AttributeConversionResult::hasResult() const { |
3257 | return impl.getInt() == resultTag; |
3258 | } |
3259 | |
3260 | bool TypeConverter::AttributeConversionResult::isNa() const { |
3261 | return impl.getInt() == naTag; |
3262 | } |
3263 | |
3264 | bool TypeConverter::AttributeConversionResult::isAbort() const { |
3265 | return impl.getInt() == abortTag; |
3266 | } |
3267 | |
3268 | Attribute TypeConverter::AttributeConversionResult::getResult() const { |
3269 | assert(hasResult() && "Cannot get result from N/A or abort" ); |
3270 | return impl.getPointer(); |
3271 | } |
3272 | |
3273 | std::optional<Attribute> |
3274 | TypeConverter::convertTypeAttribute(Type type, Attribute attr) const { |
3275 | for (const TypeAttributeConversionCallbackFn &fn : |
3276 | llvm::reverse(C: typeAttributeConversions)) { |
3277 | AttributeConversionResult res = fn(type, attr); |
3278 | if (res.hasResult()) |
3279 | return res.getResult(); |
3280 | if (res.isAbort()) |
3281 | return std::nullopt; |
3282 | } |
3283 | return std::nullopt; |
3284 | } |
3285 | |
3286 | //===----------------------------------------------------------------------===// |
3287 | // FunctionOpInterfaceSignatureConversion |
3288 | //===----------------------------------------------------------------------===// |
3289 | |
3290 | static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, |
3291 | const TypeConverter &typeConverter, |
3292 | ConversionPatternRewriter &rewriter) { |
3293 | FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType()); |
3294 | if (!type) |
3295 | return failure(); |
3296 | |
3297 | // Convert the original function types. |
3298 | TypeConverter::SignatureConversion result(type.getNumInputs()); |
3299 | SmallVector<Type, 1> newResults; |
3300 | if (failed(typeConverter.convertSignatureArgs(types: type.getInputs(), result)) || |
3301 | failed(typeConverter.convertTypes(types: type.getResults(), results&: newResults)) || |
3302 | failed(rewriter.convertRegionTypes(region: &funcOp.getFunctionBody(), |
3303 | converter: typeConverter, entryConversion: &result))) |
3304 | return failure(); |
3305 | |
3306 | // Update the function signature in-place. |
3307 | auto newType = FunctionType::get(rewriter.getContext(), |
3308 | result.getConvertedTypes(), newResults); |
3309 | |
3310 | rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); }); |
3311 | |
3312 | return success(); |
3313 | } |
3314 | |
3315 | /// Create a default conversion pattern that rewrites the type signature of a |
3316 | /// FunctionOpInterface op. This only supports ops which use FunctionType to |
3317 | /// represent their type. |
3318 | namespace { |
3319 | struct FunctionOpInterfaceSignatureConversion : public ConversionPattern { |
3320 | FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName, |
3321 | MLIRContext *ctx, |
3322 | const TypeConverter &converter) |
3323 | : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {} |
3324 | |
3325 | LogicalResult |
3326 | matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/, |
3327 | ConversionPatternRewriter &rewriter) const override { |
3328 | FunctionOpInterface funcOp = cast<FunctionOpInterface>(op); |
3329 | return convertFuncOpTypes(funcOp, *typeConverter, rewriter); |
3330 | } |
3331 | }; |
3332 | |
3333 | struct AnyFunctionOpInterfaceSignatureConversion |
3334 | : public OpInterfaceConversionPattern<FunctionOpInterface> { |
3335 | using OpInterfaceConversionPattern::OpInterfaceConversionPattern; |
3336 | |
3337 | LogicalResult |
3338 | matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/, |
3339 | ConversionPatternRewriter &rewriter) const override { |
3340 | return convertFuncOpTypes(funcOp, *typeConverter, rewriter); |
3341 | } |
3342 | }; |
3343 | } // namespace |
3344 | |
3345 | FailureOr<Operation *> |
3346 | mlir::convertOpResultTypes(Operation *op, ValueRange operands, |
3347 | const TypeConverter &converter, |
3348 | ConversionPatternRewriter &rewriter) { |
3349 | assert(op && "Invalid op" ); |
3350 | Location loc = op->getLoc(); |
3351 | if (converter.isLegal(op)) |
3352 | return rewriter.notifyMatchFailure(arg&: loc, msg: "op already legal" ); |
3353 | |
3354 | OperationState newOp(loc, op->getName()); |
3355 | newOp.addOperands(newOperands: operands); |
3356 | |
3357 | SmallVector<Type> newResultTypes; |
3358 | if (failed(result: converter.convertTypes(types: op->getResultTypes(), results&: newResultTypes))) |
3359 | return rewriter.notifyMatchFailure(arg&: loc, msg: "couldn't convert return types" ); |
3360 | |
3361 | newOp.addTypes(newTypes: newResultTypes); |
3362 | newOp.addAttributes(newAttributes: op->getAttrs()); |
3363 | return rewriter.create(state: newOp); |
3364 | } |
3365 | |
3366 | void mlir::populateFunctionOpInterfaceTypeConversionPattern( |
3367 | StringRef functionLikeOpName, RewritePatternSet &patterns, |
3368 | const TypeConverter &converter) { |
3369 | patterns.add<FunctionOpInterfaceSignatureConversion>( |
3370 | arg&: functionLikeOpName, args: patterns.getContext(), args: converter); |
3371 | } |
3372 | |
3373 | void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern( |
3374 | RewritePatternSet &patterns, const TypeConverter &converter) { |
3375 | patterns.add<AnyFunctionOpInterfaceSignatureConversion>( |
3376 | arg: converter, args: patterns.getContext()); |
3377 | } |
3378 | |
3379 | //===----------------------------------------------------------------------===// |
3380 | // ConversionTarget |
3381 | //===----------------------------------------------------------------------===// |
3382 | |
3383 | void ConversionTarget::setOpAction(OperationName op, |
3384 | LegalizationAction action) { |
3385 | legalOperations[op].action = action; |
3386 | } |
3387 | |
3388 | void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames, |
3389 | LegalizationAction action) { |
3390 | for (StringRef dialect : dialectNames) |
3391 | legalDialects[dialect] = action; |
3392 | } |
3393 | |
3394 | auto ConversionTarget::getOpAction(OperationName op) const |
3395 | -> std::optional<LegalizationAction> { |
3396 | std::optional<LegalizationInfo> info = getOpInfo(op); |
3397 | return info ? info->action : std::optional<LegalizationAction>(); |
3398 | } |
3399 | |
3400 | auto ConversionTarget::isLegal(Operation *op) const |
3401 | -> std::optional<LegalOpDetails> { |
3402 | std::optional<LegalizationInfo> info = getOpInfo(op: op->getName()); |
3403 | if (!info) |
3404 | return std::nullopt; |
3405 | |
3406 | // Returns true if this operation instance is known to be legal. |
3407 | auto isOpLegal = [&] { |
3408 | // Handle dynamic legality either with the provided legality function. |
3409 | if (info->action == LegalizationAction::Dynamic) { |
3410 | std::optional<bool> result = info->legalityFn(op); |
3411 | if (result) |
3412 | return *result; |
3413 | } |
3414 | |
3415 | // Otherwise, the operation is only legal if it was marked 'Legal'. |
3416 | return info->action == LegalizationAction::Legal; |
3417 | }; |
3418 | if (!isOpLegal()) |
3419 | return std::nullopt; |
3420 | |
3421 | // This operation is legal, compute any additional legality information. |
3422 | LegalOpDetails legalityDetails; |
3423 | if (info->isRecursivelyLegal) { |
3424 | auto legalityFnIt = opRecursiveLegalityFns.find(Val: op->getName()); |
3425 | if (legalityFnIt != opRecursiveLegalityFns.end()) { |
3426 | legalityDetails.isRecursivelyLegal = |
3427 | legalityFnIt->second(op).value_or(u: true); |
3428 | } else { |
3429 | legalityDetails.isRecursivelyLegal = true; |
3430 | } |
3431 | } |
3432 | return legalityDetails; |
3433 | } |
3434 | |
3435 | bool ConversionTarget::isIllegal(Operation *op) const { |
3436 | std::optional<LegalizationInfo> info = getOpInfo(op: op->getName()); |
3437 | if (!info) |
3438 | return false; |
3439 | |
3440 | if (info->action == LegalizationAction::Dynamic) { |
3441 | std::optional<bool> result = info->legalityFn(op); |
3442 | if (!result) |
3443 | return false; |
3444 | |
3445 | return !(*result); |
3446 | } |
3447 | |
3448 | return info->action == LegalizationAction::Illegal; |
3449 | } |
3450 | |
3451 | static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks( |
3452 | ConversionTarget::DynamicLegalityCallbackFn oldCallback, |
3453 | ConversionTarget::DynamicLegalityCallbackFn newCallback) { |
3454 | if (!oldCallback) |
3455 | return newCallback; |
3456 | |
3457 | auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)]( |
3458 | Operation *op) -> std::optional<bool> { |
3459 | if (std::optional<bool> result = newCl(op)) |
3460 | return *result; |
3461 | |
3462 | return oldCl(op); |
3463 | }; |
3464 | return chain; |
3465 | } |
3466 | |
3467 | void ConversionTarget::setLegalityCallback( |
3468 | OperationName name, const DynamicLegalityCallbackFn &callback) { |
3469 | assert(callback && "expected valid legality callback" ); |
3470 | auto *infoIt = legalOperations.find(Key: name); |
3471 | assert(infoIt != legalOperations.end() && |
3472 | infoIt->second.action == LegalizationAction::Dynamic && |
3473 | "expected operation to already be marked as dynamically legal" ); |
3474 | infoIt->second.legalityFn = |
3475 | composeLegalityCallbacks(oldCallback: std::move(infoIt->second.legalityFn), newCallback: callback); |
3476 | } |
3477 | |
3478 | void ConversionTarget::markOpRecursivelyLegal( |
3479 | OperationName name, const DynamicLegalityCallbackFn &callback) { |
3480 | auto *infoIt = legalOperations.find(Key: name); |
3481 | assert(infoIt != legalOperations.end() && |
3482 | infoIt->second.action != LegalizationAction::Illegal && |
3483 | "expected operation to already be marked as legal" ); |
3484 | infoIt->second.isRecursivelyLegal = true; |
3485 | if (callback) |
3486 | opRecursiveLegalityFns[name] = composeLegalityCallbacks( |
3487 | oldCallback: std::move(opRecursiveLegalityFns[name]), newCallback: callback); |
3488 | else |
3489 | opRecursiveLegalityFns.erase(Val: name); |
3490 | } |
3491 | |
3492 | void ConversionTarget::setLegalityCallback( |
3493 | ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) { |
3494 | assert(callback && "expected valid legality callback" ); |
3495 | for (StringRef dialect : dialects) |
3496 | dialectLegalityFns[dialect] = composeLegalityCallbacks( |
3497 | oldCallback: std::move(dialectLegalityFns[dialect]), newCallback: callback); |
3498 | } |
3499 | |
3500 | void ConversionTarget::setLegalityCallback( |
3501 | const DynamicLegalityCallbackFn &callback) { |
3502 | assert(callback && "expected valid legality callback" ); |
3503 | unknownLegalityFn = composeLegalityCallbacks(oldCallback: unknownLegalityFn, newCallback: callback); |
3504 | } |
3505 | |
3506 | auto ConversionTarget::getOpInfo(OperationName op) const |
3507 | -> std::optional<LegalizationInfo> { |
3508 | // Check for info for this specific operation. |
3509 | const auto *it = legalOperations.find(Key: op); |
3510 | if (it != legalOperations.end()) |
3511 | return it->second; |
3512 | // Check for info for the parent dialect. |
3513 | auto dialectIt = legalDialects.find(Key: op.getDialectNamespace()); |
3514 | if (dialectIt != legalDialects.end()) { |
3515 | DynamicLegalityCallbackFn callback; |
3516 | auto dialectFn = dialectLegalityFns.find(Key: op.getDialectNamespace()); |
3517 | if (dialectFn != dialectLegalityFns.end()) |
3518 | callback = dialectFn->second; |
3519 | return LegalizationInfo{.action: dialectIt->second, /*isRecursivelyLegal=*/false, |
3520 | .legalityFn: callback}; |
3521 | } |
3522 | // Otherwise, check if we mark unknown operations as dynamic. |
3523 | if (unknownLegalityFn) |
3524 | return LegalizationInfo{.action: LegalizationAction::Dynamic, |
3525 | /*isRecursivelyLegal=*/false, .legalityFn: unknownLegalityFn}; |
3526 | return std::nullopt; |
3527 | } |
3528 | |
3529 | #if MLIR_ENABLE_PDL_IN_PATTERNMATCH |
3530 | //===----------------------------------------------------------------------===// |
3531 | // PDL Configuration |
3532 | //===----------------------------------------------------------------------===// |
3533 | |
3534 | void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) { |
3535 | auto &rewriterImpl = |
3536 | static_cast<ConversionPatternRewriter &>(rewriter).getImpl(); |
3537 | rewriterImpl.currentTypeConverter = getTypeConverter(); |
3538 | } |
3539 | |
3540 | void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) { |
3541 | auto &rewriterImpl = |
3542 | static_cast<ConversionPatternRewriter &>(rewriter).getImpl(); |
3543 | rewriterImpl.currentTypeConverter = nullptr; |
3544 | } |
3545 | |
3546 | /// Remap the given value using the rewriter and the type converter in the |
3547 | /// provided config. |
3548 | static FailureOr<SmallVector<Value>> |
3549 | pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values) { |
3550 | SmallVector<Value> mappedValues; |
3551 | if (failed(result: rewriter.getRemappedValues(keys: values, results&: mappedValues))) |
3552 | return failure(); |
3553 | return std::move(mappedValues); |
3554 | } |
3555 | |
3556 | void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) { |
3557 | patterns.getPDLPatterns().registerRewriteFunction( |
3558 | name: "convertValue" , |
3559 | rewriteFn: [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> { |
3560 | auto results = pdllConvertValues( |
3561 | rewriter&: static_cast<ConversionPatternRewriter &>(rewriter), values: value); |
3562 | if (failed(result: results)) |
3563 | return failure(); |
3564 | return results->front(); |
3565 | }); |
3566 | patterns.getPDLPatterns().registerRewriteFunction( |
3567 | name: "convertValues" , rewriteFn: [](PatternRewriter &rewriter, ValueRange values) { |
3568 | return pdllConvertValues( |
3569 | rewriter&: static_cast<ConversionPatternRewriter &>(rewriter), values); |
3570 | }); |
3571 | patterns.getPDLPatterns().registerRewriteFunction( |
3572 | name: "convertType" , |
3573 | rewriteFn: [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> { |
3574 | auto &rewriterImpl = |
3575 | static_cast<ConversionPatternRewriter &>(rewriter).getImpl(); |
3576 | if (const TypeConverter *converter = |
3577 | rewriterImpl.currentTypeConverter) { |
3578 | if (Type newType = converter->convertType(t: type)) |
3579 | return newType; |
3580 | return failure(); |
3581 | } |
3582 | return type; |
3583 | }); |
3584 | patterns.getPDLPatterns().registerRewriteFunction( |
3585 | name: "convertTypes" , |
3586 | rewriteFn: [](PatternRewriter &rewriter, |
3587 | TypeRange types) -> FailureOr<SmallVector<Type>> { |
3588 | auto &rewriterImpl = |
3589 | static_cast<ConversionPatternRewriter &>(rewriter).getImpl(); |
3590 | const TypeConverter *converter = rewriterImpl.currentTypeConverter; |
3591 | if (!converter) |
3592 | return SmallVector<Type>(types); |
3593 | |
3594 | SmallVector<Type> remappedTypes; |
3595 | if (failed(result: converter->convertTypes(types, results&: remappedTypes))) |
3596 | return failure(); |
3597 | return std::move(remappedTypes); |
3598 | }); |
3599 | } |
3600 | #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH |
3601 | |
3602 | //===----------------------------------------------------------------------===// |
3603 | // Op Conversion Entry Points |
3604 | //===----------------------------------------------------------------------===// |
3605 | |
3606 | //===----------------------------------------------------------------------===// |
3607 | // Partial Conversion |
3608 | |
3609 | LogicalResult mlir::applyPartialConversion( |
3610 | ArrayRef<Operation *> ops, const ConversionTarget &target, |
3611 | const FrozenRewritePatternSet &patterns, ConversionConfig config) { |
3612 | OperationConverter opConverter(target, patterns, config, |
3613 | OpConversionMode::Partial); |
3614 | return opConverter.convertOperations(ops); |
3615 | } |
3616 | LogicalResult |
3617 | mlir::applyPartialConversion(Operation *op, const ConversionTarget &target, |
3618 | const FrozenRewritePatternSet &patterns, |
3619 | ConversionConfig config) { |
3620 | return applyPartialConversion(ops: llvm::ArrayRef(op), target, patterns, config); |
3621 | } |
3622 | |
3623 | //===----------------------------------------------------------------------===// |
3624 | // Full Conversion |
3625 | |
3626 | LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops, |
3627 | const ConversionTarget &target, |
3628 | const FrozenRewritePatternSet &patterns, |
3629 | ConversionConfig config) { |
3630 | OperationConverter opConverter(target, patterns, config, |
3631 | OpConversionMode::Full); |
3632 | return opConverter.convertOperations(ops); |
3633 | } |
3634 | LogicalResult mlir::applyFullConversion(Operation *op, |
3635 | const ConversionTarget &target, |
3636 | const FrozenRewritePatternSet &patterns, |
3637 | ConversionConfig config) { |
3638 | return applyFullConversion(ops: llvm::ArrayRef(op), target, patterns, config); |
3639 | } |
3640 | |
3641 | //===----------------------------------------------------------------------===// |
3642 | // Analysis Conversion |
3643 | |
3644 | LogicalResult mlir::applyAnalysisConversion( |
3645 | ArrayRef<Operation *> ops, ConversionTarget &target, |
3646 | const FrozenRewritePatternSet &patterns, ConversionConfig config) { |
3647 | OperationConverter opConverter(target, patterns, config, |
3648 | OpConversionMode::Analysis); |
3649 | return opConverter.convertOperations(ops); |
3650 | } |
3651 | LogicalResult |
3652 | mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target, |
3653 | const FrozenRewritePatternSet &patterns, |
3654 | ConversionConfig config) { |
3655 | return applyAnalysisConversion(ops: llvm::ArrayRef(op), target, patterns, config); |
3656 | } |
3657 | |