1 | //===- Builders.h - Helpers for constructing MLIR Classes -------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_IR_BUILDERS_H |
10 | #define MLIR_IR_BUILDERS_H |
11 | |
12 | #include "mlir/IR/OpDefinition.h" |
13 | #include "llvm/Support/Compiler.h" |
14 | #include <optional> |
15 | |
16 | namespace mlir { |
17 | |
// Forward declarations of MLIR classes. Several of these appear (by value or
// by reference) in the Builder / OpBuilder signatures that follow.

// Locations.
class UnknownLoc;
class FileLineColLoc;

// Types.
class Type;
class PrimitiveType;
class IntegerType;
class FloatType;
class FunctionType;
class IndexType;
class MemRefType;
class VectorType;
class RankedTensorType;
class UnrankedTensorType;
class TupleType;
class NoneType;

// Attributes.
class BoolAttr;
class IntegerAttr;
class FloatAttr;
class StringAttr;
class TypeAttr;
class ArrayAttr;
class SymbolRefAttr;
class ElementsAttr;
class DenseElementsAttr;
class DenseIntElementsAttr;
class AffineMapAttr;
class UnitAttr;

// Affine expressions and maps.
class AffineExpr;
class AffineMap;

// Value/operation remapping used when cloning IR.
class IRMapping;
47 | |
48 | /// This class is a general helper class for creating context-global objects |
49 | /// like types, attributes, and affine expressions. |
50 | class Builder { |
51 | public: |
52 | explicit Builder(MLIRContext *context) : context(context) {} |
53 | explicit Builder(Operation *op) : Builder(op->getContext()) {} |
54 | |
55 | MLIRContext *getContext() const { return context; } |
56 | |
57 | // Locations. |
58 | Location getUnknownLoc(); |
59 | Location getFusedLoc(ArrayRef<Location> locs, |
60 | Attribute metadata = Attribute()); |
61 | |
62 | // Types. |
63 | FloatType getFloat8E5M2Type(); |
64 | FloatType getFloat8E4M3FNType(); |
65 | FloatType getFloat8E5M2FNUZType(); |
66 | FloatType getFloat8E4M3FNUZType(); |
67 | FloatType getFloat8E4M3B11FNUZType(); |
68 | FloatType getBF16Type(); |
69 | FloatType getF16Type(); |
70 | FloatType getTF32Type(); |
71 | FloatType getF32Type(); |
72 | FloatType getF64Type(); |
73 | FloatType getF80Type(); |
74 | FloatType getF128Type(); |
75 | |
76 | IndexType getIndexType(); |
77 | |
78 | IntegerType getI1Type(); |
79 | IntegerType getI2Type(); |
80 | IntegerType getI4Type(); |
81 | IntegerType getI8Type(); |
82 | IntegerType getI16Type(); |
83 | IntegerType getI32Type(); |
84 | IntegerType getI64Type(); |
85 | IntegerType getIntegerType(unsigned width); |
86 | IntegerType getIntegerType(unsigned width, bool isSigned); |
87 | FunctionType getFunctionType(TypeRange inputs, TypeRange results); |
88 | TupleType getTupleType(TypeRange elementTypes); |
89 | NoneType getNoneType(); |
90 | |
91 | /// Get or construct an instance of the type `Ty` with provided arguments. |
92 | template <typename Ty, typename... Args> |
93 | Ty getType(Args &&...args) { |
94 | return Ty::get(context, std::forward<Args>(args)...); |
95 | } |
96 | |
97 | /// Get or construct an instance of the attribute `Attr` with provided |
98 | /// arguments. |
99 | template <typename Attr, typename... Args> |
100 | Attr getAttr(Args &&...args) { |
101 | return Attr::get(context, std::forward<Args>(args)...); |
102 | } |
103 | |
104 | // Attributes. |
105 | NamedAttribute getNamedAttr(StringRef name, Attribute val); |
106 | |
107 | UnitAttr getUnitAttr(); |
108 | BoolAttr getBoolAttr(bool value); |
109 | DictionaryAttr getDictionaryAttr(ArrayRef<NamedAttribute> value); |
110 | IntegerAttr getIntegerAttr(Type type, int64_t value); |
111 | IntegerAttr getIntegerAttr(Type type, const APInt &value); |
112 | FloatAttr getFloatAttr(Type type, double value); |
113 | FloatAttr getFloatAttr(Type type, const APFloat &value); |
114 | StringAttr getStringAttr(const Twine &bytes); |
115 | ArrayAttr getArrayAttr(ArrayRef<Attribute> value); |
116 | |
117 | // Returns a 0-valued attribute of the given `type`. This function only |
118 | // supports boolean, integer, and 16-/32-/64-bit float types, and vector or |
119 | // ranked tensor of them. Returns null attribute otherwise. |
120 | TypedAttr getZeroAttr(Type type); |
121 | // Returns a 1-valued attribute of the given `type`. |
122 | // Type constraints are the same as `getZeroAttr`. |
123 | TypedAttr getOneAttr(Type type); |
124 | |
125 | // Convenience methods for fixed types. |
126 | FloatAttr getF16FloatAttr(float value); |
127 | FloatAttr getF32FloatAttr(float value); |
128 | FloatAttr getF64FloatAttr(double value); |
129 | |
130 | IntegerAttr getI8IntegerAttr(int8_t value); |
131 | IntegerAttr getI16IntegerAttr(int16_t value); |
132 | IntegerAttr getI32IntegerAttr(int32_t value); |
133 | IntegerAttr getI64IntegerAttr(int64_t value); |
134 | IntegerAttr getIndexAttr(int64_t value); |
135 | |
136 | /// Signed and unsigned integer attribute getters. |
137 | IntegerAttr getSI32IntegerAttr(int32_t value); |
138 | IntegerAttr getUI32IntegerAttr(uint32_t value); |
139 | |
140 | /// Vector-typed DenseIntElementsAttr getters. `values` must not be empty. |
141 | DenseIntElementsAttr getBoolVectorAttr(ArrayRef<bool> values); |
142 | DenseIntElementsAttr getI32VectorAttr(ArrayRef<int32_t> values); |
143 | DenseIntElementsAttr getI64VectorAttr(ArrayRef<int64_t> values); |
144 | DenseIntElementsAttr getIndexVectorAttr(ArrayRef<int64_t> values); |
145 | |
146 | DenseFPElementsAttr getF32VectorAttr(ArrayRef<float> values); |
147 | DenseFPElementsAttr getF64VectorAttr(ArrayRef<double> values); |
148 | |
149 | /// Tensor-typed DenseIntElementsAttr getters. `values` can be empty. |
150 | /// These are generally preferable for representing general lists of integers |
151 | /// as attributes. |
152 | DenseIntElementsAttr getI32TensorAttr(ArrayRef<int32_t> values); |
153 | DenseIntElementsAttr getI64TensorAttr(ArrayRef<int64_t> values); |
154 | DenseIntElementsAttr getIndexTensorAttr(ArrayRef<int64_t> values); |
155 | |
156 | /// Tensor-typed DenseArrayAttr getters. |
157 | DenseBoolArrayAttr getDenseBoolArrayAttr(ArrayRef<bool> values); |
158 | DenseI8ArrayAttr getDenseI8ArrayAttr(ArrayRef<int8_t> values); |
159 | DenseI16ArrayAttr getDenseI16ArrayAttr(ArrayRef<int16_t> values); |
160 | DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef<int32_t> values); |
161 | DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef<int64_t> values); |
162 | DenseF32ArrayAttr getDenseF32ArrayAttr(ArrayRef<float> values); |
163 | DenseF64ArrayAttr getDenseF64ArrayAttr(ArrayRef<double> values); |
164 | |
165 | ArrayAttr getAffineMapArrayAttr(ArrayRef<AffineMap> values); |
166 | ArrayAttr getBoolArrayAttr(ArrayRef<bool> values); |
167 | ArrayAttr getI32ArrayAttr(ArrayRef<int32_t> values); |
168 | ArrayAttr getI64ArrayAttr(ArrayRef<int64_t> values); |
169 | ArrayAttr getIndexArrayAttr(ArrayRef<int64_t> values); |
170 | ArrayAttr getF32ArrayAttr(ArrayRef<float> values); |
171 | ArrayAttr getF64ArrayAttr(ArrayRef<double> values); |
172 | ArrayAttr getStrArrayAttr(ArrayRef<StringRef> values); |
173 | ArrayAttr getTypeArrayAttr(TypeRange values); |
174 | |
175 | // Affine expressions and affine maps. |
176 | AffineExpr getAffineDimExpr(unsigned position); |
177 | AffineExpr getAffineSymbolExpr(unsigned position); |
178 | AffineExpr getAffineConstantExpr(int64_t constant); |
179 | |
180 | // Special cases of affine maps and integer sets |
181 | /// Returns a zero result affine map with no dimensions or symbols: () -> (). |
182 | AffineMap getEmptyAffineMap(); |
183 | /// Returns a single constant result affine map with 0 dimensions and 0 |
184 | /// symbols. One constant result: () -> (val). |
185 | AffineMap getConstantAffineMap(int64_t val); |
186 | // One dimension id identity map: (i) -> (i). |
187 | AffineMap getDimIdentityMap(); |
188 | // Multi-dimensional identity map: (d0, d1, d2) -> (d0, d1, d2). |
189 | AffineMap getMultiDimIdentityMap(unsigned rank); |
190 | // One symbol identity map: ()[s] -> (s). |
191 | AffineMap getSymbolIdentityMap(); |
192 | |
193 | /// Returns a map that shifts its (single) input dimension by 'shift'. |
194 | /// (d0) -> (d0 + shift) |
195 | AffineMap getSingleDimShiftAffineMap(int64_t shift); |
196 | |
197 | /// Returns an affine map that is a translation (shift) of all result |
198 | /// expressions in 'map' by 'shift'. |
199 | /// Eg: input: (d0, d1)[s0] -> (d0, d1 + s0), shift = 2 |
200 | /// returns: (d0, d1)[s0] -> (d0 + 2, d1 + s0 + 2) |
201 | AffineMap getShiftedAffineMap(AffineMap map, int64_t shift); |
202 | |
203 | protected: |
204 | MLIRContext *context; |
205 | }; |
206 | |
207 | /// This class helps build Operations. Operations that are created are |
208 | /// automatically inserted at an insertion point. The builder is copyable. |
209 | class OpBuilder : public Builder { |
210 | public: |
211 | class InsertPoint; |
212 | struct Listener; |
213 | |
214 | /// Create a builder with the given context. |
215 | explicit OpBuilder(MLIRContext *ctx, Listener *listener = nullptr) |
216 | : Builder(ctx), listener(listener) {} |
217 | |
218 | /// Create a builder and set the insertion point to the start of the region. |
219 | explicit OpBuilder(Region *region, Listener *listener = nullptr) |
220 | : OpBuilder(region->getContext(), listener) { |
221 | if (!region->empty()) |
222 | setInsertionPoint(block: ®ion->front(), insertPoint: region->front().begin()); |
223 | } |
224 | explicit OpBuilder(Region ®ion, Listener *listener = nullptr) |
225 | : OpBuilder(®ion, listener) {} |
226 | |
227 | /// Create a builder and set insertion point to the given operation, which |
228 | /// will cause subsequent insertions to go right before it. |
229 | explicit OpBuilder(Operation *op, Listener *listener = nullptr) |
230 | : OpBuilder(op->getContext(), listener) { |
231 | setInsertionPoint(op); |
232 | } |
233 | |
234 | OpBuilder(Block *block, Block::iterator insertPoint, |
235 | Listener *listener = nullptr) |
236 | : OpBuilder(block->getParent()->getContext(), listener) { |
237 | setInsertionPoint(block, insertPoint); |
238 | } |
239 | |
240 | /// Create a builder and set the insertion point to before the first operation |
241 | /// in the block but still inside the block. |
242 | static OpBuilder atBlockBegin(Block *block, Listener *listener = nullptr) { |
243 | return OpBuilder(block, block->begin(), listener); |
244 | } |
245 | |
246 | /// Create a builder and set the insertion point to after the last operation |
247 | /// in the block but still inside the block. |
248 | static OpBuilder atBlockEnd(Block *block, Listener *listener = nullptr) { |
249 | return OpBuilder(block, block->end(), listener); |
250 | } |
251 | |
252 | /// Create a builder and set the insertion point to before the block |
253 | /// terminator. |
254 | static OpBuilder atBlockTerminator(Block *block, |
255 | Listener *listener = nullptr) { |
256 | auto *terminator = block->getTerminator(); |
257 | assert(terminator != nullptr && "the block has no terminator" ); |
258 | return OpBuilder(block, Block::iterator(terminator), listener); |
259 | } |
260 | |
261 | //===--------------------------------------------------------------------===// |
262 | // Listeners |
263 | //===--------------------------------------------------------------------===// |
264 | |
265 | /// Base class for listeners. |
266 | struct ListenerBase { |
267 | /// The kind of listener. |
268 | enum class Kind { |
269 | /// OpBuilder::Listener or user-derived class. |
270 | OpBuilderListener = 0, |
271 | |
272 | /// RewriterBase::Listener or user-derived class. |
273 | RewriterBaseListener = 1 |
274 | }; |
275 | |
276 | Kind getKind() const { return kind; } |
277 | |
278 | protected: |
279 | ListenerBase(Kind kind) : kind(kind) {} |
280 | |
281 | private: |
282 | const Kind kind; |
283 | }; |
284 | |
285 | /// This class represents a listener that may be used to hook into various |
286 | /// actions within an OpBuilder. |
287 | struct Listener : public ListenerBase { |
288 | Listener() : ListenerBase(ListenerBase::Kind::OpBuilderListener) {} |
289 | |
290 | virtual ~Listener() = default; |
291 | |
292 | /// Notify the listener that the specified operation was inserted. |
293 | /// |
294 | /// * If the operation was moved, then `previous` is the previous location |
295 | /// of the op. |
296 | /// * If the operation was unlinked before it was inserted, then `previous` |
297 | /// is empty. |
298 | /// |
299 | /// Note: Creating an (unlinked) op does not trigger this notification. |
300 | virtual void notifyOperationInserted(Operation *op, InsertPoint previous) {} |
301 | |
302 | /// Notify the listener that the specified block was inserted. |
303 | /// |
304 | /// * If the block was moved, then `previous` and `previousIt` are the |
305 | /// previous location of the block. |
306 | /// * If the block was unlinked before it was inserted, then `previous` |
307 | /// is "nullptr". |
308 | /// |
309 | /// Note: Creating an (unlinked) block does not trigger this notification. |
310 | virtual void notifyBlockInserted(Block *block, Region *previous, |
311 | Region::iterator previousIt) {} |
312 | |
313 | protected: |
314 | Listener(Kind kind) : ListenerBase(kind) {} |
315 | }; |
316 | |
317 | /// Sets the listener of this builder to the one provided. |
318 | void setListener(Listener *newListener) { listener = newListener; } |
319 | |
320 | /// Returns the current listener of this builder, or nullptr if this builder |
321 | /// doesn't have a listener. |
322 | Listener *getListener() const { return listener; } |
323 | |
324 | //===--------------------------------------------------------------------===// |
325 | // Insertion Point Management |
326 | //===--------------------------------------------------------------------===// |
327 | |
328 | /// This class represents a saved insertion point. |
329 | class InsertPoint { |
330 | public: |
331 | /// Creates a new insertion point which doesn't point to anything. |
332 | InsertPoint() = default; |
333 | |
334 | /// Creates a new insertion point at the given location. |
335 | InsertPoint(Block *insertBlock, Block::iterator insertPt) |
336 | : block(insertBlock), point(insertPt) {} |
337 | |
338 | /// Returns true if this insert point is set. |
339 | bool isSet() const { return (block != nullptr); } |
340 | |
341 | Block *getBlock() const { return block; } |
342 | Block::iterator getPoint() const { return point; } |
343 | |
344 | private: |
345 | Block *block = nullptr; |
346 | Block::iterator point; |
347 | }; |
348 | |
349 | /// RAII guard to reset the insertion point of the builder when destroyed. |
350 | class InsertionGuard { |
351 | public: |
352 | InsertionGuard(OpBuilder &builder) |
353 | : builder(&builder), ip(builder.saveInsertionPoint()) {} |
354 | |
355 | ~InsertionGuard() { |
356 | if (builder) |
357 | builder->restoreInsertionPoint(ip); |
358 | } |
359 | |
360 | InsertionGuard(const InsertionGuard &) = delete; |
361 | InsertionGuard &operator=(const InsertionGuard &) = delete; |
362 | |
363 | /// Implement the move constructor to clear the builder field of `other`. |
364 | /// That way it does not restore the insertion point upon destruction as |
365 | /// that should be done exclusively by the just constructed InsertionGuard. |
366 | InsertionGuard(InsertionGuard &&other) noexcept |
367 | : builder(other.builder), ip(other.ip) { |
368 | other.builder = nullptr; |
369 | } |
370 | |
371 | InsertionGuard &operator=(InsertionGuard &&other) = delete; |
372 | |
373 | private: |
374 | OpBuilder *builder; |
375 | OpBuilder::InsertPoint ip; |
376 | }; |
377 | |
378 | /// Reset the insertion point to no location. Creating an operation without a |
379 | /// set insertion point is an error, but this can still be useful when the |
380 | /// current insertion point a builder refers to is being removed. |
381 | void clearInsertionPoint() { |
382 | this->block = nullptr; |
383 | insertPoint = Block::iterator(); |
384 | } |
385 | |
386 | /// Return a saved insertion point. |
387 | InsertPoint saveInsertionPoint() const { |
388 | return InsertPoint(getInsertionBlock(), getInsertionPoint()); |
389 | } |
390 | |
391 | /// Restore the insert point to a previously saved point. |
392 | void restoreInsertionPoint(InsertPoint ip) { |
393 | if (ip.isSet()) |
394 | setInsertionPoint(block: ip.getBlock(), insertPoint: ip.getPoint()); |
395 | else |
396 | clearInsertionPoint(); |
397 | } |
398 | |
399 | /// Set the insertion point to the specified location. |
400 | void setInsertionPoint(Block *block, Block::iterator insertPoint) { |
401 | // TODO: check that insertPoint is in this rather than some other block. |
402 | this->block = block; |
403 | this->insertPoint = insertPoint; |
404 | } |
405 | |
406 | /// Sets the insertion point to the specified operation, which will cause |
407 | /// subsequent insertions to go right before it. |
408 | void setInsertionPoint(Operation *op) { |
409 | setInsertionPoint(block: op->getBlock(), insertPoint: Block::iterator(op)); |
410 | } |
411 | |
412 | /// Sets the insertion point to the node after the specified operation, which |
413 | /// will cause subsequent insertions to go right after it. |
414 | void setInsertionPointAfter(Operation *op) { |
415 | setInsertionPoint(block: op->getBlock(), insertPoint: ++Block::iterator(op)); |
416 | } |
417 | |
418 | /// Sets the insertion point to the node after the specified value. If value |
419 | /// has a defining operation, sets the insertion point to the node after such |
420 | /// defining operation. This will cause subsequent insertions to go right |
421 | /// after it. Otherwise, value is a BlockArgument. Sets the insertion point to |
422 | /// the start of its block. |
423 | void setInsertionPointAfterValue(Value val) { |
424 | if (Operation *op = val.getDefiningOp()) { |
425 | setInsertionPointAfter(op); |
426 | } else { |
427 | auto blockArg = llvm::cast<BlockArgument>(Val&: val); |
428 | setInsertionPointToStart(blockArg.getOwner()); |
429 | } |
430 | } |
431 | |
432 | /// Sets the insertion point to the start of the specified block. |
433 | void setInsertionPointToStart(Block *block) { |
434 | setInsertionPoint(block, insertPoint: block->begin()); |
435 | } |
436 | |
437 | /// Sets the insertion point to the end of the specified block. |
438 | void setInsertionPointToEnd(Block *block) { |
439 | setInsertionPoint(block, insertPoint: block->end()); |
440 | } |
441 | |
442 | /// Return the block the current insertion point belongs to. Note that the |
443 | /// insertion point is not necessarily the end of the block. |
444 | Block *getInsertionBlock() const { return block; } |
445 | |
446 | /// Returns the current insertion point of the builder. |
447 | Block::iterator getInsertionPoint() const { return insertPoint; } |
448 | |
449 | /// Returns the current block of the builder. |
450 | Block *getBlock() const { return block; } |
451 | |
452 | //===--------------------------------------------------------------------===// |
453 | // Block Creation |
454 | //===--------------------------------------------------------------------===// |
455 | |
456 | /// Add new block with 'argTypes' arguments and set the insertion point to the |
457 | /// end of it. The block is inserted at the provided insertion point of |
458 | /// 'parent'. `locs` contains the locations of the inserted arguments, and |
459 | /// should match the size of `argTypes`. |
460 | Block *createBlock(Region *parent, Region::iterator insertPt = {}, |
461 | TypeRange argTypes = std::nullopt, |
462 | ArrayRef<Location> locs = std::nullopt); |
463 | |
464 | /// Add new block with 'argTypes' arguments and set the insertion point to the |
465 | /// end of it. The block is placed before 'insertBefore'. `locs` contains the |
466 | /// locations of the inserted arguments, and should match the size of |
467 | /// `argTypes`. |
468 | Block *createBlock(Block *insertBefore, TypeRange argTypes = std::nullopt, |
469 | ArrayRef<Location> locs = std::nullopt); |
470 | |
471 | //===--------------------------------------------------------------------===// |
472 | // Operation Creation |
473 | //===--------------------------------------------------------------------===// |
474 | |
475 | /// Insert the given operation at the current insertion point and return it. |
476 | Operation *insert(Operation *op); |
477 | |
478 | /// Creates an operation given the fields represented as an OperationState. |
479 | Operation *create(const OperationState &state); |
480 | |
481 | /// Creates an operation with the given fields. |
482 | Operation *create(Location loc, StringAttr opName, ValueRange operands, |
483 | TypeRange types = {}, |
484 | ArrayRef<NamedAttribute> attributes = {}, |
485 | BlockRange successors = {}, |
486 | MutableArrayRef<std::unique_ptr<Region>> regions = {}); |
487 | |
488 | private: |
489 | /// Helper for sanity checking preconditions for create* methods below. |
490 | template <typename OpT> |
491 | RegisteredOperationName getCheckRegisteredInfo(MLIRContext *ctx) { |
492 | std::optional<RegisteredOperationName> opName = |
493 | RegisteredOperationName::lookup(TypeID::get<OpT>(), ctx); |
494 | if (LLVM_UNLIKELY(!opName)) { |
495 | llvm::report_fatal_error( |
496 | "Building op `" + OpT::getOperationName() + |
497 | "` but it isn't known in this MLIRContext: the dialect may not " |
498 | "be loaded or this operation hasn't been added by the dialect. See " |
499 | "also https://mlir.llvm.org/getting_started/Faq/" |
500 | "#registered-loaded-dependent-whats-up-with-dialects-management" ); |
501 | } |
502 | return *opName; |
503 | } |
504 | |
505 | public: |
506 | /// Create an operation of specific op type at the current insertion point. |
507 | template <typename OpTy, typename... Args> |
508 | OpTy create(Location location, Args &&...args) { |
509 | OperationState state(location, |
510 | getCheckRegisteredInfo<OpTy>(location.getContext())); |
511 | OpTy::build(*this, state, std::forward<Args>(args)...); |
512 | auto *op = create(state); |
513 | auto result = dyn_cast<OpTy>(op); |
514 | assert(result && "builder didn't return the right type" ); |
515 | return result; |
516 | } |
517 | |
518 | /// Create an operation of specific op type at the current insertion point, |
519 | /// and immediately try to fold it. This functions populates 'results' with |
520 | /// the results of the operation. |
521 | template <typename OpTy, typename... Args> |
522 | void createOrFold(SmallVectorImpl<Value> &results, Location location, |
523 | Args &&...args) { |
524 | // Create the operation without using 'create' as we want to control when |
525 | // the listener is notified. |
526 | OperationState state(location, |
527 | getCheckRegisteredInfo<OpTy>(location.getContext())); |
528 | OpTy::build(*this, state, std::forward<Args>(args)...); |
529 | Operation *op = Operation::create(state); |
530 | if (block) |
531 | block->getOperations().insert(where: insertPoint, New: op); |
532 | |
533 | // Attempt to fold the operation. |
534 | if (succeeded(result: tryFold(op, results)) && !results.empty()) { |
535 | // Erase the operation, if the fold removed the need for this operation. |
536 | // Note: The fold already populated the results in this case. |
537 | op->erase(); |
538 | return; |
539 | } |
540 | |
541 | ResultRange opResults = op->getResults(); |
542 | results.assign(in_start: opResults.begin(), in_end: opResults.end()); |
543 | if (block && listener) |
544 | listener->notifyOperationInserted(op, /*previous=*/previous: {}); |
545 | } |
546 | |
547 | /// Overload to create or fold a single result operation. |
548 | template <typename OpTy, typename... Args> |
549 | std::enable_if_t<OpTy::template hasTrait<OpTrait::OneResult>(), Value> |
550 | createOrFold(Location location, Args &&...args) { |
551 | SmallVector<Value, 1> results; |
552 | createOrFold<OpTy>(results, location, std::forward<Args>(args)...); |
553 | return results.front(); |
554 | } |
555 | |
556 | /// Overload to create or fold a zero result operation. |
557 | template <typename OpTy, typename... Args> |
558 | std::enable_if_t<OpTy::template hasTrait<OpTrait::ZeroResults>(), OpTy> |
559 | createOrFold(Location location, Args &&...args) { |
560 | auto op = create<OpTy>(location, std::forward<Args>(args)...); |
561 | SmallVector<Value, 0> unused; |
562 | (void)tryFold(op: op.getOperation(), results&: unused); |
563 | |
564 | // Folding cannot remove a zero-result operation, so for convenience we |
565 | // continue to return it. |
566 | return op; |
567 | } |
568 | |
569 | /// Attempts to fold the given operation and places new results within |
570 | /// `results`. Returns success if the operation was folded, failure otherwise. |
571 | /// If the fold was in-place, `results` will not be filled. |
572 | /// Note: This function does not erase the operation on a successful fold. |
573 | LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results); |
574 | |
575 | /// Creates a deep copy of the specified operation, remapping any operands |
576 | /// that use values outside of the operation using the map that is provided |
577 | /// ( leaving them alone if no entry is present). Replaces references to |
578 | /// cloned sub-operations to the corresponding operation that is copied, |
579 | /// and adds those mappings to the map. |
580 | Operation *clone(Operation &op, IRMapping &mapper); |
581 | Operation *clone(Operation &op); |
582 | |
583 | /// Creates a deep copy of this operation but keep the operation regions |
584 | /// empty. Operands are remapped using `mapper` (if present), and `mapper` is |
585 | /// updated to contain the results. |
586 | Operation *cloneWithoutRegions(Operation &op, IRMapping &mapper) { |
587 | return insert(op: op.cloneWithoutRegions(mapper)); |
588 | } |
589 | Operation *cloneWithoutRegions(Operation &op) { |
590 | return insert(op: op.cloneWithoutRegions()); |
591 | } |
592 | template <typename OpT> |
593 | OpT cloneWithoutRegions(OpT op) { |
594 | return cast<OpT>(cloneWithoutRegions(*op.getOperation())); |
595 | } |
596 | |
597 | /// Clone the blocks that belong to "region" before the given position in |
598 | /// another region "parent". The two regions must be different. The caller is |
599 | /// responsible for creating or updating the operation transferring flow of |
600 | /// control to the region and passing it the correct block arguments. |
601 | void cloneRegionBefore(Region ®ion, Region &parent, |
602 | Region::iterator before, IRMapping &mapping); |
603 | void cloneRegionBefore(Region ®ion, Region &parent, |
604 | Region::iterator before); |
605 | void cloneRegionBefore(Region ®ion, Block *before); |
606 | |
607 | protected: |
608 | /// The optional listener for events of this builder. |
609 | Listener *listener; |
610 | |
611 | private: |
612 | /// The current block this builder is inserting into. |
613 | Block *block = nullptr; |
614 | /// The insertion point within the block that this builder is inserting |
615 | /// before. |
616 | Block::iterator insertPoint; |
617 | }; |
618 | |
619 | } // namespace mlir |
620 | |
621 | #endif |
622 | |