| 1 | //===- Builders.h - Helpers for constructing MLIR Classes -------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #ifndef MLIR_IR_BUILDERS_H |
| 10 | #define MLIR_IR_BUILDERS_H |
| 11 | |
| 12 | #include "mlir/IR/OpDefinition.h" |
| 13 | #include "llvm/Support/Compiler.h" |
| 14 | #include <optional> |
| 15 | |
| 16 | namespace mlir { |
| 17 | |
// Forward declarations of IR classes that are only named (not defined) in this
// header.

// Affine structures, value mappings, and location-related classes.
class AffineExpr;
class AffineMap;
class IRMapping;
class UnknownLoc;
class FileLineColLoc;
class FileLineColRange;

// Type classes.
class Type;
class IntegerType;
class FloatType;
class FunctionType;
class IndexType;
class MemRefType;
class VectorType;
class RankedTensorType;
class UnrankedTensorType;
class TupleType;
class NoneType;

// Attribute classes.
class BoolAttr;
class IntegerAttr;
class FloatAttr;
class StringAttr;
class TypeAttr;
class ArrayAttr;
class SymbolRefAttr;
class ElementsAttr;
class DenseElementsAttr;
class DenseIntElementsAttr;
class AffineMapAttr;
class UnitAttr;
| 47 | |
| 48 | /// This class is a general helper class for creating context-global objects |
| 49 | /// like types, attributes, and affine expressions. |
| 50 | class Builder { |
| 51 | public: |
| 52 | explicit Builder(MLIRContext *context) : context(context) {} |
| 53 | explicit Builder(Operation *op) : Builder(op->getContext()) {} |
| 54 | |
| 55 | MLIRContext *getContext() const { return context; } |
| 56 | |
| 57 | // Locations. |
| 58 | Location getUnknownLoc(); |
| 59 | Location getFusedLoc(ArrayRef<Location> locs, |
| 60 | Attribute metadata = Attribute()); |
| 61 | |
| 62 | // Types. |
| 63 | FloatType getF8E8M0Type(); |
| 64 | FloatType getBF16Type(); |
| 65 | FloatType getF16Type(); |
| 66 | FloatType getTF32Type(); |
| 67 | FloatType getF32Type(); |
| 68 | FloatType getF64Type(); |
| 69 | FloatType getF80Type(); |
| 70 | FloatType getF128Type(); |
| 71 | |
| 72 | IndexType getIndexType(); |
| 73 | |
| 74 | IntegerType getI1Type(); |
| 75 | IntegerType getI2Type(); |
| 76 | IntegerType getI4Type(); |
| 77 | IntegerType getI8Type(); |
| 78 | IntegerType getI16Type(); |
| 79 | IntegerType getI32Type(); |
| 80 | IntegerType getI64Type(); |
| 81 | IntegerType getIntegerType(unsigned width); |
| 82 | IntegerType getIntegerType(unsigned width, bool isSigned); |
| 83 | FunctionType getFunctionType(TypeRange inputs, TypeRange results); |
| 84 | TupleType getTupleType(TypeRange elementTypes); |
| 85 | NoneType getNoneType(); |
| 86 | |
| 87 | /// Get or construct an instance of the type `Ty` with provided arguments. |
| 88 | template <typename Ty, typename... Args> |
| 89 | Ty getType(Args &&...args) { |
| 90 | return Ty::get(context, std::forward<Args>(args)...); |
| 91 | } |
| 92 | |
| 93 | /// Get or construct an instance of the attribute `Attr` with provided |
| 94 | /// arguments. |
| 95 | template <typename Attr, typename... Args> |
| 96 | Attr getAttr(Args &&...args) { |
| 97 | return Attr::get(context, std::forward<Args>(args)...); |
| 98 | } |
| 99 | |
| 100 | // Attributes. |
| 101 | NamedAttribute getNamedAttr(StringRef name, Attribute val); |
| 102 | |
| 103 | UnitAttr getUnitAttr(); |
| 104 | BoolAttr getBoolAttr(bool value); |
| 105 | DictionaryAttr getDictionaryAttr(ArrayRef<NamedAttribute> value); |
| 106 | IntegerAttr getIntegerAttr(Type type, int64_t value); |
| 107 | IntegerAttr getIntegerAttr(Type type, const APInt &value); |
| 108 | FloatAttr getFloatAttr(Type type, double value); |
| 109 | FloatAttr getFloatAttr(Type type, const APFloat &value); |
| 110 | StringAttr getStringAttr(const Twine &bytes); |
| 111 | ArrayAttr getArrayAttr(ArrayRef<Attribute> value); |
| 112 | |
| 113 | // Returns a 0-valued attribute of the given `type`. This function only |
| 114 | // supports boolean, integer, and 16-/32-/64-bit float types, and vector or |
| 115 | // ranked tensor of them. Returns null attribute otherwise. |
| 116 | TypedAttr getZeroAttr(Type type); |
| 117 | // Returns a 1-valued attribute of the given `type`. |
| 118 | // Type constraints are the same as `getZeroAttr`. |
| 119 | TypedAttr getOneAttr(Type type); |
| 120 | |
| 121 | // Convenience methods for fixed types. |
| 122 | FloatAttr getF16FloatAttr(float value); |
| 123 | FloatAttr getF32FloatAttr(float value); |
| 124 | FloatAttr getF64FloatAttr(double value); |
| 125 | |
| 126 | IntegerAttr getI8IntegerAttr(int8_t value); |
| 127 | IntegerAttr getI16IntegerAttr(int16_t value); |
| 128 | IntegerAttr getI32IntegerAttr(int32_t value); |
| 129 | IntegerAttr getI64IntegerAttr(int64_t value); |
| 130 | IntegerAttr getIndexAttr(int64_t value); |
| 131 | |
| 132 | /// Signed and unsigned integer attribute getters. |
| 133 | IntegerAttr getSI32IntegerAttr(int32_t value); |
| 134 | IntegerAttr getUI32IntegerAttr(uint32_t value); |
| 135 | |
| 136 | /// Vector-typed DenseIntElementsAttr getters. `values` must not be empty. |
| 137 | DenseIntElementsAttr getBoolVectorAttr(ArrayRef<bool> values); |
| 138 | DenseIntElementsAttr getI32VectorAttr(ArrayRef<int32_t> values); |
| 139 | DenseIntElementsAttr getI64VectorAttr(ArrayRef<int64_t> values); |
| 140 | DenseIntElementsAttr getIndexVectorAttr(ArrayRef<int64_t> values); |
| 141 | |
| 142 | DenseFPElementsAttr getF32VectorAttr(ArrayRef<float> values); |
| 143 | DenseFPElementsAttr getF64VectorAttr(ArrayRef<double> values); |
| 144 | |
| 145 | /// Tensor-typed DenseIntElementsAttr getters. `values` can be empty. |
| 146 | /// These are generally preferable for representing general lists of integers |
| 147 | /// as attributes. |
| 148 | DenseIntElementsAttr getI32TensorAttr(ArrayRef<int32_t> values); |
| 149 | DenseIntElementsAttr getI64TensorAttr(ArrayRef<int64_t> values); |
| 150 | DenseIntElementsAttr getIndexTensorAttr(ArrayRef<int64_t> values); |
| 151 | |
| 152 | /// Tensor-typed DenseArrayAttr getters. |
| 153 | DenseBoolArrayAttr getDenseBoolArrayAttr(ArrayRef<bool> values); |
| 154 | DenseI8ArrayAttr getDenseI8ArrayAttr(ArrayRef<int8_t> values); |
| 155 | DenseI16ArrayAttr getDenseI16ArrayAttr(ArrayRef<int16_t> values); |
| 156 | DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef<int32_t> values); |
| 157 | DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef<int64_t> values); |
| 158 | DenseF32ArrayAttr getDenseF32ArrayAttr(ArrayRef<float> values); |
| 159 | DenseF64ArrayAttr getDenseF64ArrayAttr(ArrayRef<double> values); |
| 160 | |
| 161 | ArrayAttr getAffineMapArrayAttr(ArrayRef<AffineMap> values); |
| 162 | ArrayAttr getBoolArrayAttr(ArrayRef<bool> values); |
| 163 | ArrayAttr getI32ArrayAttr(ArrayRef<int32_t> values); |
| 164 | ArrayAttr getI64ArrayAttr(ArrayRef<int64_t> values); |
| 165 | ArrayAttr getIndexArrayAttr(ArrayRef<int64_t> values); |
| 166 | ArrayAttr getF32ArrayAttr(ArrayRef<float> values); |
| 167 | ArrayAttr getF64ArrayAttr(ArrayRef<double> values); |
| 168 | ArrayAttr getStrArrayAttr(ArrayRef<StringRef> values); |
| 169 | ArrayAttr getTypeArrayAttr(TypeRange values); |
| 170 | |
| 171 | // Affine expressions and affine maps. |
| 172 | AffineExpr getAffineDimExpr(unsigned position); |
| 173 | AffineExpr getAffineSymbolExpr(unsigned position); |
| 174 | AffineExpr getAffineConstantExpr(int64_t constant); |
| 175 | |
| 176 | // Special cases of affine maps and integer sets |
| 177 | /// Returns a zero result affine map with no dimensions or symbols: () -> (). |
| 178 | AffineMap getEmptyAffineMap(); |
| 179 | /// Returns a single constant result affine map with 0 dimensions and 0 |
| 180 | /// symbols. One constant result: () -> (val). |
| 181 | AffineMap getConstantAffineMap(int64_t val); |
| 182 | // One dimension id identity map: (i) -> (i). |
| 183 | AffineMap getDimIdentityMap(); |
| 184 | // Multi-dimensional identity map: (d0, d1, d2) -> (d0, d1, d2). |
| 185 | AffineMap getMultiDimIdentityMap(unsigned rank); |
| 186 | // One symbol identity map: ()[s] -> (s). |
| 187 | AffineMap getSymbolIdentityMap(); |
| 188 | |
| 189 | /// Returns a map that shifts its (single) input dimension by 'shift'. |
| 190 | /// (d0) -> (d0 + shift) |
| 191 | AffineMap getSingleDimShiftAffineMap(int64_t shift); |
| 192 | |
| 193 | /// Returns an affine map that is a translation (shift) of all result |
| 194 | /// expressions in 'map' by 'shift'. |
| 195 | /// Eg: input: (d0, d1)[s0] -> (d0, d1 + s0), shift = 2 |
| 196 | /// returns: (d0, d1)[s0] -> (d0 + 2, d1 + s0 + 2) |
| 197 | AffineMap getShiftedAffineMap(AffineMap map, int64_t shift); |
| 198 | |
| 199 | protected: |
| 200 | MLIRContext *context; |
| 201 | }; |
| 202 | |
| 203 | /// This class helps build Operations. Operations that are created are |
| 204 | /// automatically inserted at an insertion point. The builder is copyable. |
| 205 | class OpBuilder : public Builder { |
| 206 | public: |
| 207 | class InsertPoint; |
| 208 | struct Listener; |
| 209 | |
| 210 | /// Create a builder with the given context. |
| 211 | explicit OpBuilder(MLIRContext *ctx, Listener *listener = nullptr) |
| 212 | : Builder(ctx), listener(listener) {} |
| 213 | |
| 214 | /// Create a builder and set the insertion point to the start of the region. |
| 215 | explicit OpBuilder(Region *region, Listener *listener = nullptr) |
| 216 | : OpBuilder(region->getContext(), listener) { |
| 217 | if (!region->empty()) |
| 218 | setInsertionPointToStart(®ion->front()); |
| 219 | } |
| 220 | explicit OpBuilder(Region ®ion, Listener *listener = nullptr) |
| 221 | : OpBuilder(®ion, listener) {} |
| 222 | |
| 223 | /// Create a builder and set insertion point to the given operation, which |
| 224 | /// will cause subsequent insertions to go right before it. |
| 225 | explicit OpBuilder(Operation *op, Listener *listener = nullptr) |
| 226 | : OpBuilder(op->getContext(), listener) { |
| 227 | setInsertionPoint(op); |
| 228 | } |
| 229 | |
| 230 | OpBuilder(Block *block, Block::iterator insertPoint, |
| 231 | Listener *listener = nullptr) |
| 232 | : OpBuilder(block->getParent()->getContext(), listener) { |
| 233 | setInsertionPoint(block, insertPoint); |
| 234 | } |
| 235 | |
| 236 | /// Create a builder and set the insertion point to before the first operation |
| 237 | /// in the block but still inside the block. |
| 238 | static OpBuilder atBlockBegin(Block *block, Listener *listener = nullptr) { |
| 239 | return OpBuilder(block, block->begin(), listener); |
| 240 | } |
| 241 | |
| 242 | /// Create a builder and set the insertion point to after the last operation |
| 243 | /// in the block but still inside the block. |
| 244 | static OpBuilder atBlockEnd(Block *block, Listener *listener = nullptr) { |
| 245 | return OpBuilder(block, block->end(), listener); |
| 246 | } |
| 247 | |
| 248 | /// Create a builder and set the insertion point to before the block |
| 249 | /// terminator. |
| 250 | static OpBuilder atBlockTerminator(Block *block, |
| 251 | Listener *listener = nullptr) { |
| 252 | auto *terminator = block->getTerminator(); |
| 253 | assert(terminator != nullptr && "the block has no terminator" ); |
| 254 | return OpBuilder(block, Block::iterator(terminator), listener); |
| 255 | } |
| 256 | |
| 257 | //===--------------------------------------------------------------------===// |
| 258 | // Listeners |
| 259 | //===--------------------------------------------------------------------===// |
| 260 | |
| 261 | /// Base class for listeners. |
| 262 | struct ListenerBase { |
| 263 | /// The kind of listener. |
| 264 | enum class Kind { |
| 265 | /// OpBuilder::Listener or user-derived class. |
| 266 | OpBuilderListener = 0, |
| 267 | |
| 268 | /// RewriterBase::Listener or user-derived class. |
| 269 | RewriterBaseListener = 1 |
| 270 | }; |
| 271 | |
| 272 | Kind getKind() const { return kind; } |
| 273 | |
| 274 | protected: |
| 275 | ListenerBase(Kind kind) : kind(kind) {} |
| 276 | |
| 277 | private: |
| 278 | const Kind kind; |
| 279 | }; |
| 280 | |
| 281 | /// This class represents a listener that may be used to hook into various |
| 282 | /// actions within an OpBuilder. |
| 283 | struct Listener : public ListenerBase { |
| 284 | Listener() : ListenerBase(ListenerBase::Kind::OpBuilderListener) {} |
| 285 | |
| 286 | virtual ~Listener() = default; |
| 287 | |
| 288 | /// Notify the listener that the specified operation was inserted. |
| 289 | /// |
| 290 | /// * If the operation was moved, then `previous` is the previous location |
| 291 | /// of the op. |
| 292 | /// * If the operation was unlinked before it was inserted, then `previous` |
| 293 | /// is empty. |
| 294 | /// |
| 295 | /// Note: Creating an (unlinked) op does not trigger this notification. |
| 296 | virtual void notifyOperationInserted(Operation *op, InsertPoint previous) {} |
| 297 | |
| 298 | /// Notify the listener that the specified block was inserted. |
| 299 | /// |
| 300 | /// * If the block was moved, then `previous` and `previousIt` are the |
| 301 | /// previous location of the block. |
| 302 | /// * If the block was unlinked before it was inserted, then `previous` |
| 303 | /// is "nullptr". |
| 304 | /// |
| 305 | /// Note: Creating an (unlinked) block does not trigger this notification. |
| 306 | virtual void notifyBlockInserted(Block *block, Region *previous, |
| 307 | Region::iterator previousIt) {} |
| 308 | |
| 309 | protected: |
| 310 | Listener(Kind kind) : ListenerBase(kind) {} |
| 311 | }; |
| 312 | |
| 313 | /// Sets the listener of this builder to the one provided. |
| 314 | void setListener(Listener *newListener) { listener = newListener; } |
| 315 | |
| 316 | /// Returns the current listener of this builder, or nullptr if this builder |
| 317 | /// doesn't have a listener. |
| 318 | Listener *getListener() const { return listener; } |
| 319 | |
| 320 | //===--------------------------------------------------------------------===// |
| 321 | // Insertion Point Management |
| 322 | //===--------------------------------------------------------------------===// |
| 323 | |
| 324 | /// This class represents a saved insertion point. |
| 325 | class InsertPoint { |
| 326 | public: |
| 327 | /// Creates a new insertion point which doesn't point to anything. |
| 328 | InsertPoint() = default; |
| 329 | |
| 330 | /// Creates a new insertion point at the given location. |
| 331 | InsertPoint(Block *insertBlock, Block::iterator insertPt) |
| 332 | : block(insertBlock), point(insertPt) {} |
| 333 | |
| 334 | /// Returns true if this insert point is set. |
| 335 | bool isSet() const { return (block != nullptr); } |
| 336 | |
| 337 | Block *getBlock() const { return block; } |
| 338 | Block::iterator getPoint() const { return point; } |
| 339 | |
| 340 | private: |
| 341 | Block *block = nullptr; |
| 342 | Block::iterator point; |
| 343 | }; |
| 344 | |
| 345 | /// RAII guard to reset the insertion point of the builder when destroyed. |
| 346 | class InsertionGuard { |
| 347 | public: |
| 348 | InsertionGuard(OpBuilder &builder) |
| 349 | : builder(&builder), ip(builder.saveInsertionPoint()) {} |
| 350 | |
| 351 | ~InsertionGuard() { |
| 352 | if (builder) |
| 353 | builder->restoreInsertionPoint(ip); |
| 354 | } |
| 355 | |
| 356 | InsertionGuard(const InsertionGuard &) = delete; |
| 357 | InsertionGuard &operator=(const InsertionGuard &) = delete; |
| 358 | |
| 359 | /// Implement the move constructor to clear the builder field of `other`. |
| 360 | /// That way it does not restore the insertion point upon destruction as |
| 361 | /// that should be done exclusively by the just constructed InsertionGuard. |
| 362 | InsertionGuard(InsertionGuard &&other) noexcept |
| 363 | : builder(other.builder), ip(other.ip) { |
| 364 | other.builder = nullptr; |
| 365 | } |
| 366 | |
| 367 | InsertionGuard &operator=(InsertionGuard &&other) = delete; |
| 368 | |
| 369 | private: |
| 370 | OpBuilder *builder; |
| 371 | OpBuilder::InsertPoint ip; |
| 372 | }; |
| 373 | |
| 374 | /// Reset the insertion point to no location. Creating an operation without a |
| 375 | /// set insertion point is an error, but this can still be useful when the |
| 376 | /// current insertion point a builder refers to is being removed. |
| 377 | void clearInsertionPoint() { |
| 378 | this->block = nullptr; |
| 379 | insertPoint = Block::iterator(); |
| 380 | } |
| 381 | |
| 382 | /// Return a saved insertion point. |
| 383 | InsertPoint saveInsertionPoint() const { |
| 384 | return InsertPoint(getInsertionBlock(), getInsertionPoint()); |
| 385 | } |
| 386 | |
| 387 | /// Restore the insert point to a previously saved point. |
| 388 | void restoreInsertionPoint(InsertPoint ip) { |
| 389 | if (ip.isSet()) |
| 390 | setInsertionPoint(block: ip.getBlock(), insertPoint: ip.getPoint()); |
| 391 | else |
| 392 | clearInsertionPoint(); |
| 393 | } |
| 394 | |
| 395 | /// Set the insertion point to the specified location. |
| 396 | void setInsertionPoint(Block *block, Block::iterator insertPoint) { |
| 397 | // TODO: check that insertPoint is in this rather than some other block. |
| 398 | this->block = block; |
| 399 | this->insertPoint = insertPoint; |
| 400 | } |
| 401 | |
| 402 | /// Sets the insertion point to the specified operation, which will cause |
| 403 | /// subsequent insertions to go right before it. |
| 404 | void setInsertionPoint(Operation *op) { |
| 405 | setInsertionPoint(block: op->getBlock(), insertPoint: Block::iterator(op)); |
| 406 | } |
| 407 | |
| 408 | /// Sets the insertion point to the node after the specified operation, which |
| 409 | /// will cause subsequent insertions to go right after it. |
| 410 | void setInsertionPointAfter(Operation *op) { |
| 411 | setInsertionPoint(block: op->getBlock(), insertPoint: ++Block::iterator(op)); |
| 412 | } |
| 413 | |
| 414 | /// Sets the insertion point to the node after the specified value. If value |
| 415 | /// has a defining operation, sets the insertion point to the node after such |
| 416 | /// defining operation. This will cause subsequent insertions to go right |
| 417 | /// after it. Otherwise, value is a BlockArgument. Sets the insertion point to |
| 418 | /// the start of its block. |
| 419 | void setInsertionPointAfterValue(Value val) { |
| 420 | if (Operation *op = val.getDefiningOp()) { |
| 421 | setInsertionPointAfter(op); |
| 422 | } else { |
| 423 | auto blockArg = llvm::cast<BlockArgument>(Val&: val); |
| 424 | setInsertionPointToStart(blockArg.getOwner()); |
| 425 | } |
| 426 | } |
| 427 | |
| 428 | /// Sets the insertion point to the start of the specified block. |
| 429 | void setInsertionPointToStart(Block *block) { |
| 430 | setInsertionPoint(block, insertPoint: block->begin()); |
| 431 | } |
| 432 | |
| 433 | /// Sets the insertion point to the end of the specified block. |
| 434 | void setInsertionPointToEnd(Block *block) { |
| 435 | setInsertionPoint(block, insertPoint: block->end()); |
| 436 | } |
| 437 | |
| 438 | /// Return the block the current insertion point belongs to. Note that the |
| 439 | /// insertion point is not necessarily the end of the block. |
| 440 | Block *getInsertionBlock() const { return block; } |
| 441 | |
| 442 | /// Returns the current insertion point of the builder. |
| 443 | Block::iterator getInsertionPoint() const { return insertPoint; } |
| 444 | |
| 445 | /// Returns the current block of the builder. |
| 446 | Block *getBlock() const { return block; } |
| 447 | |
| 448 | //===--------------------------------------------------------------------===// |
| 449 | // Block Creation |
| 450 | //===--------------------------------------------------------------------===// |
| 451 | |
| 452 | /// Add new block with 'argTypes' arguments and set the insertion point to the |
| 453 | /// end of it. The block is inserted at the provided insertion point of |
| 454 | /// 'parent'. `locs` contains the locations of the inserted arguments, and |
| 455 | /// should match the size of `argTypes`. |
| 456 | Block *createBlock(Region *parent, Region::iterator insertPt = {}, |
| 457 | TypeRange argTypes = std::nullopt, |
| 458 | ArrayRef<Location> locs = std::nullopt); |
| 459 | |
| 460 | /// Add new block with 'argTypes' arguments and set the insertion point to the |
| 461 | /// end of it. The block is placed before 'insertBefore'. `locs` contains the |
| 462 | /// locations of the inserted arguments, and should match the size of |
| 463 | /// `argTypes`. |
| 464 | Block *createBlock(Block *insertBefore, TypeRange argTypes = std::nullopt, |
| 465 | ArrayRef<Location> locs = std::nullopt); |
| 466 | |
| 467 | //===--------------------------------------------------------------------===// |
| 468 | // Operation Creation |
| 469 | //===--------------------------------------------------------------------===// |
| 470 | |
| 471 | /// Insert the given operation at the current insertion point and return it. |
| 472 | Operation *insert(Operation *op); |
| 473 | |
| 474 | /// Creates an operation given the fields represented as an OperationState. |
| 475 | Operation *create(const OperationState &state); |
| 476 | |
| 477 | /// Creates an operation with the given fields. |
| 478 | Operation *create(Location loc, StringAttr opName, ValueRange operands, |
| 479 | TypeRange types = {}, |
| 480 | ArrayRef<NamedAttribute> attributes = {}, |
| 481 | BlockRange successors = {}, |
| 482 | MutableArrayRef<std::unique_ptr<Region>> regions = {}); |
| 483 | |
| 484 | private: |
| 485 | /// Helper for sanity checking preconditions for create* methods below. |
| 486 | template <typename OpT> |
| 487 | RegisteredOperationName getCheckRegisteredInfo(MLIRContext *ctx) { |
| 488 | std::optional<RegisteredOperationName> opName = |
| 489 | RegisteredOperationName::lookup(TypeID::get<OpT>(), ctx); |
| 490 | if (LLVM_UNLIKELY(!opName)) { |
| 491 | llvm::report_fatal_error( |
| 492 | "Building op `" + OpT::getOperationName() + |
| 493 | "` but it isn't known in this MLIRContext: the dialect may not " |
| 494 | "be loaded or this operation hasn't been added by the dialect. See " |
| 495 | "also https://mlir.llvm.org/getting_started/Faq/" |
| 496 | "#registered-loaded-dependent-whats-up-with-dialects-management" ); |
| 497 | } |
| 498 | return *opName; |
| 499 | } |
| 500 | |
| 501 | public: |
| 502 | /// Create an operation of specific op type at the current insertion point. |
| 503 | template <typename OpTy, typename... Args> |
| 504 | OpTy create(Location location, Args &&...args) { |
| 505 | OperationState state(location, |
| 506 | getCheckRegisteredInfo<OpTy>(location.getContext())); |
| 507 | OpTy::build(*this, state, std::forward<Args>(args)...); |
| 508 | auto *op = create(state); |
| 509 | auto result = dyn_cast<OpTy>(op); |
| 510 | assert(result && "builder didn't return the right type" ); |
| 511 | return result; |
| 512 | } |
| 513 | |
| 514 | /// Create an operation of specific op type at the current insertion point, |
| 515 | /// and immediately try to fold it. This functions populates 'results' with |
| 516 | /// the results of the operation. |
| 517 | template <typename OpTy, typename... Args> |
| 518 | void createOrFold(SmallVectorImpl<Value> &results, Location location, |
| 519 | Args &&...args) { |
| 520 | // Create the operation without using 'create' as we want to control when |
| 521 | // the listener is notified. |
| 522 | OperationState state(location, |
| 523 | getCheckRegisteredInfo<OpTy>(location.getContext())); |
| 524 | OpTy::build(*this, state, std::forward<Args>(args)...); |
| 525 | Operation *op = Operation::create(state); |
| 526 | if (block) |
| 527 | block->getOperations().insert(where: insertPoint, New: op); |
| 528 | |
| 529 | // Attempt to fold the operation. |
| 530 | if (succeeded(Result: tryFold(op, results)) && !results.empty()) { |
| 531 | // Erase the operation, if the fold removed the need for this operation. |
| 532 | // Note: The fold already populated the results in this case. |
| 533 | op->erase(); |
| 534 | return; |
| 535 | } |
| 536 | |
| 537 | ResultRange opResults = op->getResults(); |
| 538 | results.assign(in_start: opResults.begin(), in_end: opResults.end()); |
| 539 | if (block && listener) |
| 540 | listener->notifyOperationInserted(op, /*previous=*/previous: {}); |
| 541 | } |
| 542 | |
| 543 | /// Overload to create or fold a single result operation. |
| 544 | template <typename OpTy, typename... Args> |
| 545 | std::enable_if_t<OpTy::template hasTrait<OpTrait::OneResult>(), Value> |
| 546 | createOrFold(Location location, Args &&...args) { |
| 547 | SmallVector<Value, 1> results; |
| 548 | createOrFold<OpTy>(results, location, std::forward<Args>(args)...); |
| 549 | return results.front(); |
| 550 | } |
| 551 | |
| 552 | /// Overload to create or fold a zero result operation. |
| 553 | template <typename OpTy, typename... Args> |
| 554 | std::enable_if_t<OpTy::template hasTrait<OpTrait::ZeroResults>(), OpTy> |
| 555 | createOrFold(Location location, Args &&...args) { |
| 556 | auto op = create<OpTy>(location, std::forward<Args>(args)...); |
| 557 | SmallVector<Value, 0> unused; |
| 558 | (void)tryFold(op: op.getOperation(), results&: unused); |
| 559 | |
| 560 | // Folding cannot remove a zero-result operation, so for convenience we |
| 561 | // continue to return it. |
| 562 | return op; |
| 563 | } |
| 564 | |
| 565 | /// Attempts to fold the given operation and places new results within |
| 566 | /// `results`. Returns success if the operation was folded, failure otherwise. |
| 567 | /// If the fold was in-place, `results` will not be filled. Optionally, newly |
| 568 | /// materialized constant operations can be returned to the caller. |
| 569 | /// |
| 570 | /// Note: This function does not erase the operation on a successful fold. |
| 571 | LogicalResult |
| 572 | tryFold(Operation *op, SmallVectorImpl<Value> &results, |
| 573 | SmallVectorImpl<Operation *> *materializedConstants = nullptr); |
| 574 | |
| 575 | /// Creates a deep copy of the specified operation, remapping any operands |
| 576 | /// that use values outside of the operation using the map that is provided |
| 577 | /// ( leaving them alone if no entry is present). Replaces references to |
| 578 | /// cloned sub-operations to the corresponding operation that is copied, |
| 579 | /// and adds those mappings to the map. |
| 580 | Operation *clone(Operation &op, IRMapping &mapper); |
| 581 | Operation *clone(Operation &op); |
| 582 | |
| 583 | /// Creates a deep copy of this operation but keep the operation regions |
| 584 | /// empty. Operands are remapped using `mapper` (if present), and `mapper` is |
| 585 | /// updated to contain the results. |
| 586 | Operation *cloneWithoutRegions(Operation &op, IRMapping &mapper) { |
| 587 | return insert(op: op.cloneWithoutRegions(mapper)); |
| 588 | } |
| 589 | Operation *cloneWithoutRegions(Operation &op) { |
| 590 | return insert(op: op.cloneWithoutRegions()); |
| 591 | } |
| 592 | template <typename OpT> |
| 593 | OpT cloneWithoutRegions(OpT op) { |
| 594 | return cast<OpT>(cloneWithoutRegions(*op.getOperation())); |
| 595 | } |
| 596 | |
| 597 | /// Clone the blocks that belong to "region" before the given position in |
| 598 | /// another region "parent". The two regions must be different. The caller is |
| 599 | /// responsible for creating or updating the operation transferring flow of |
| 600 | /// control to the region and passing it the correct block arguments. |
| 601 | void cloneRegionBefore(Region ®ion, Region &parent, |
| 602 | Region::iterator before, IRMapping &mapping); |
| 603 | void cloneRegionBefore(Region ®ion, Region &parent, |
| 604 | Region::iterator before); |
| 605 | void cloneRegionBefore(Region ®ion, Block *before); |
| 606 | |
| 607 | protected: |
| 608 | /// The optional listener for events of this builder. |
| 609 | Listener *listener; |
| 610 | |
| 611 | private: |
| 612 | /// The current block this builder is inserting into. |
| 613 | Block *block = nullptr; |
| 614 | /// The insertion point within the block that this builder is inserting |
| 615 | /// before. |
| 616 | Block::iterator insertPoint; |
| 617 | }; |
| 618 | |
| 619 | } // namespace mlir |
| 620 | |
| 621 | #endif |
| 622 | |