1 | //===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/IR/Builders.h" |
10 | #include "mlir/IR/AffineExpr.h" |
11 | #include "mlir/IR/AffineMap.h" |
12 | #include "mlir/IR/BuiltinTypes.h" |
13 | #include "mlir/IR/Dialect.h" |
14 | #include "mlir/IR/IRMapping.h" |
15 | #include "mlir/IR/IntegerSet.h" |
16 | #include "mlir/IR/Matchers.h" |
17 | #include "mlir/IR/SymbolTable.h" |
18 | #include "llvm/ADT/SmallVectorExtras.h" |
19 | #include "llvm/Support/raw_ostream.h" |
20 | |
21 | using namespace mlir; |
22 | |
23 | //===----------------------------------------------------------------------===// |
24 | // Locations. |
25 | //===----------------------------------------------------------------------===// |
26 | |
27 | Location Builder::getUnknownLoc() { return UnknownLoc::get(context); } |
28 | |
29 | Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) { |
30 | return FusedLoc::get(locs, metadata, context); |
31 | } |
32 | |
33 | //===----------------------------------------------------------------------===// |
34 | // Types. |
35 | //===----------------------------------------------------------------------===// |
36 | |
37 | FloatType Builder::getF8E8M0Type() { return Float8E8M0FNUType::get(context); } |
38 | |
39 | FloatType Builder::getBF16Type() { return BFloat16Type::get(context); } |
40 | |
41 | FloatType Builder::getF16Type() { return Float16Type::get(context); } |
42 | |
43 | FloatType Builder::getTF32Type() { return FloatTF32Type::get(context); } |
44 | |
45 | FloatType Builder::getF32Type() { return Float32Type::get(context); } |
46 | |
47 | FloatType Builder::getF64Type() { return Float64Type::get(context); } |
48 | |
49 | FloatType Builder::getF80Type() { return Float80Type::get(context); } |
50 | |
51 | FloatType Builder::getF128Type() { return Float128Type::get(context); } |
52 | |
53 | IndexType Builder::getIndexType() { return IndexType::get(context); } |
54 | |
55 | IntegerType Builder::getI1Type() { return IntegerType::get(context, 1); } |
56 | |
57 | IntegerType Builder::getI2Type() { return IntegerType::get(context, 2); } |
58 | |
59 | IntegerType Builder::getI4Type() { return IntegerType::get(context, 4); } |
60 | |
61 | IntegerType Builder::getI8Type() { return IntegerType::get(context, 8); } |
62 | |
63 | IntegerType Builder::getI16Type() { return IntegerType::get(context, 16); } |
64 | |
65 | IntegerType Builder::getI32Type() { return IntegerType::get(context, 32); } |
66 | |
67 | IntegerType Builder::getI64Type() { return IntegerType::get(context, 64); } |
68 | |
69 | IntegerType Builder::getIntegerType(unsigned width) { |
70 | return IntegerType::get(context, width); |
71 | } |
72 | |
73 | IntegerType Builder::getIntegerType(unsigned width, bool isSigned) { |
74 | return IntegerType::get( |
75 | context, width, isSigned ? IntegerType::Signed : IntegerType::Unsigned); |
76 | } |
77 | |
78 | FunctionType Builder::getFunctionType(TypeRange inputs, TypeRange results) { |
79 | return FunctionType::get(context, inputs, results); |
80 | } |
81 | |
82 | TupleType Builder::getTupleType(TypeRange elementTypes) { |
83 | return TupleType::get(context, elementTypes); |
84 | } |
85 | |
86 | NoneType Builder::getNoneType() { return NoneType::get(context); } |
87 | |
88 | //===----------------------------------------------------------------------===// |
89 | // Attributes. |
90 | //===----------------------------------------------------------------------===// |
91 | |
92 | NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) { |
93 | return NamedAttribute(name, val); |
94 | } |
95 | |
96 | UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); } |
97 | |
98 | BoolAttr Builder::getBoolAttr(bool value) { |
99 | return BoolAttr::get(context, value); |
100 | } |
101 | |
102 | DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) { |
103 | return DictionaryAttr::get(context, value); |
104 | } |
105 | |
106 | IntegerAttr Builder::getIndexAttr(int64_t value) { |
107 | return IntegerAttr::get(getIndexType(), APInt(64, value)); |
108 | } |
109 | |
110 | IntegerAttr Builder::getI64IntegerAttr(int64_t value) { |
111 | return IntegerAttr::get(getIntegerType(64), APInt(64, value)); |
112 | } |
113 | |
114 | DenseIntElementsAttr Builder::getBoolVectorAttr(ArrayRef<bool> values) { |
115 | return DenseIntElementsAttr::get( |
116 | VectorType::get(static_cast<int64_t>(values.size()), getI1Type()), |
117 | values); |
118 | } |
119 | |
120 | DenseIntElementsAttr Builder::getI32VectorAttr(ArrayRef<int32_t> values) { |
121 | return DenseIntElementsAttr::get( |
122 | VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(32)), |
123 | values); |
124 | } |
125 | |
126 | DenseIntElementsAttr Builder::getI64VectorAttr(ArrayRef<int64_t> values) { |
127 | return DenseIntElementsAttr::get( |
128 | VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(64)), |
129 | values); |
130 | } |
131 | |
132 | DenseIntElementsAttr Builder::getIndexVectorAttr(ArrayRef<int64_t> values) { |
133 | return DenseIntElementsAttr::get( |
134 | VectorType::get(static_cast<int64_t>(values.size()), getIndexType()), |
135 | values); |
136 | } |
137 | |
138 | DenseFPElementsAttr Builder::getF32VectorAttr(ArrayRef<float> values) { |
139 | return DenseFPElementsAttr::get( |
140 | VectorType::get(static_cast<float>(values.size()), getF32Type()), values); |
141 | } |
142 | |
143 | DenseFPElementsAttr Builder::getF64VectorAttr(ArrayRef<double> values) { |
144 | return DenseFPElementsAttr::get( |
145 | VectorType::get(static_cast<double>(values.size()), getF64Type()), |
146 | values); |
147 | } |
148 | |
149 | DenseBoolArrayAttr Builder::getDenseBoolArrayAttr(ArrayRef<bool> values) { |
150 | return DenseBoolArrayAttr::get(context, values); |
151 | } |
152 | |
153 | DenseI8ArrayAttr Builder::getDenseI8ArrayAttr(ArrayRef<int8_t> values) { |
154 | return DenseI8ArrayAttr::get(context, values); |
155 | } |
156 | |
157 | DenseI16ArrayAttr Builder::getDenseI16ArrayAttr(ArrayRef<int16_t> values) { |
158 | return DenseI16ArrayAttr::get(context, values); |
159 | } |
160 | |
161 | DenseI32ArrayAttr Builder::getDenseI32ArrayAttr(ArrayRef<int32_t> values) { |
162 | return DenseI32ArrayAttr::get(context, values); |
163 | } |
164 | |
165 | DenseI64ArrayAttr Builder::getDenseI64ArrayAttr(ArrayRef<int64_t> values) { |
166 | return DenseI64ArrayAttr::get(context, values); |
167 | } |
168 | |
169 | DenseF32ArrayAttr Builder::getDenseF32ArrayAttr(ArrayRef<float> values) { |
170 | return DenseF32ArrayAttr::get(context, values); |
171 | } |
172 | |
173 | DenseF64ArrayAttr Builder::getDenseF64ArrayAttr(ArrayRef<double> values) { |
174 | return DenseF64ArrayAttr::get(context, values); |
175 | } |
176 | |
177 | DenseIntElementsAttr Builder::getI32TensorAttr(ArrayRef<int32_t> values) { |
178 | return DenseIntElementsAttr::get( |
179 | RankedTensorType::get(static_cast<int64_t>(values.size()), |
180 | getIntegerType(32)), |
181 | values); |
182 | } |
183 | |
184 | DenseIntElementsAttr Builder::getI64TensorAttr(ArrayRef<int64_t> values) { |
185 | return DenseIntElementsAttr::get( |
186 | RankedTensorType::get(static_cast<int64_t>(values.size()), |
187 | getIntegerType(64)), |
188 | values); |
189 | } |
190 | |
191 | DenseIntElementsAttr Builder::getIndexTensorAttr(ArrayRef<int64_t> values) { |
192 | return DenseIntElementsAttr::get( |
193 | RankedTensorType::get(static_cast<int64_t>(values.size()), |
194 | getIndexType()), |
195 | values); |
196 | } |
197 | |
198 | IntegerAttr Builder::getI32IntegerAttr(int32_t value) { |
199 | // The APInt always uses isSigned=true here because we accept the value |
200 | // as int32_t. |
201 | return IntegerAttr::get(getIntegerType(32), |
202 | APInt(32, value, /*isSigned=*/true)); |
203 | } |
204 | |
205 | IntegerAttr Builder::getSI32IntegerAttr(int32_t value) { |
206 | return IntegerAttr::get(getIntegerType(32, /*isSigned=*/true), |
207 | APInt(32, value, /*isSigned=*/true)); |
208 | } |
209 | |
210 | IntegerAttr Builder::getUI32IntegerAttr(uint32_t value) { |
211 | return IntegerAttr::get(getIntegerType(32, /*isSigned=*/false), |
212 | APInt(32, (uint64_t)value, /*isSigned=*/false)); |
213 | } |
214 | |
215 | IntegerAttr Builder::getI16IntegerAttr(int16_t value) { |
216 | return IntegerAttr::get(getIntegerType(16), APInt(16, value)); |
217 | } |
218 | |
219 | IntegerAttr Builder::getI8IntegerAttr(int8_t value) { |
220 | // The APInt always uses isSigned=true here because we accept the value |
221 | // as int8_t. |
222 | return IntegerAttr::get(getIntegerType(8), |
223 | APInt(8, value, /*isSigned=*/true)); |
224 | } |
225 | |
226 | IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) { |
227 | if (type.isIndex()) |
228 | return IntegerAttr::get(type, APInt(64, value)); |
229 | // TODO: Avoid implicit trunc? |
230 | // See https://github.com/llvm/llvm-project/issues/112510. |
231 | return IntegerAttr::get(type, APInt(type.getIntOrFloatBitWidth(), value, |
232 | type.isSignedInteger(), |
233 | /*implicitTrunc=*/true)); |
234 | } |
235 | |
236 | IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) { |
237 | return IntegerAttr::get(type, value); |
238 | } |
239 | |
240 | FloatAttr Builder::getF64FloatAttr(double value) { |
241 | return FloatAttr::get(getF64Type(), APFloat(value)); |
242 | } |
243 | |
244 | FloatAttr Builder::getF32FloatAttr(float value) { |
245 | return FloatAttr::get(getF32Type(), APFloat(value)); |
246 | } |
247 | |
248 | FloatAttr Builder::getF16FloatAttr(float value) { |
249 | return FloatAttr::get(getF16Type(), value); |
250 | } |
251 | |
252 | FloatAttr Builder::getFloatAttr(Type type, double value) { |
253 | return FloatAttr::get(type, value); |
254 | } |
255 | |
256 | FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) { |
257 | return FloatAttr::get(type, value); |
258 | } |
259 | |
260 | StringAttr Builder::getStringAttr(const Twine &bytes) { |
261 | return StringAttr::get(context, bytes); |
262 | } |
263 | |
264 | ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) { |
265 | return ArrayAttr::get(context, value); |
266 | } |
267 | |
268 | ArrayAttr Builder::getBoolArrayAttr(ArrayRef<bool> values) { |
269 | auto attrs = llvm::map_to_vector<8>( |
270 | C&: values, F: [this](bool v) -> Attribute { return getBoolAttr(value: v); }); |
271 | return getArrayAttr(attrs); |
272 | } |
273 | |
274 | ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) { |
275 | auto attrs = llvm::map_to_vector<8>( |
276 | C&: values, F: [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); }); |
277 | return getArrayAttr(attrs); |
278 | } |
279 | ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) { |
280 | auto attrs = llvm::map_to_vector<8>( |
281 | C&: values, F: [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); }); |
282 | return getArrayAttr(attrs); |
283 | } |
284 | |
285 | ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) { |
286 | auto attrs = llvm::map_to_vector<8>(C&: values, F: [this](int64_t v) -> Attribute { |
287 | return getIntegerAttr(IndexType::get(getContext()), v); |
288 | }); |
289 | return getArrayAttr(attrs); |
290 | } |
291 | |
292 | ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) { |
293 | auto attrs = llvm::map_to_vector<8>( |
294 | C&: values, F: [this](float v) -> Attribute { return getF32FloatAttr(v); }); |
295 | return getArrayAttr(attrs); |
296 | } |
297 | |
298 | ArrayAttr Builder::getF64ArrayAttr(ArrayRef<double> values) { |
299 | auto attrs = llvm::map_to_vector<8>( |
300 | C&: values, F: [this](double v) -> Attribute { return getF64FloatAttr(v); }); |
301 | return getArrayAttr(attrs); |
302 | } |
303 | |
304 | ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) { |
305 | auto attrs = llvm::map_to_vector<8>( |
306 | C&: values, F: [this](StringRef v) -> Attribute { return getStringAttr(v); }); |
307 | return getArrayAttr(attrs); |
308 | } |
309 | |
310 | ArrayAttr Builder::getTypeArrayAttr(TypeRange values) { |
311 | auto attrs = llvm::map_to_vector<8>( |
312 | C&: values, F: [](Type v) -> Attribute { return TypeAttr::get(v); }); |
313 | return getArrayAttr(attrs); |
314 | } |
315 | |
316 | ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) { |
317 | auto attrs = llvm::map_to_vector<8>( |
318 | C&: values, F: [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); }); |
319 | return getArrayAttr(attrs); |
320 | } |
321 | |
322 | TypedAttr Builder::getZeroAttr(Type type) { |
323 | if (llvm::isa<FloatType>(Val: type)) |
324 | return getFloatAttr(type, 0.0); |
325 | if (llvm::isa<IndexType>(type)) |
326 | return getIndexAttr(0); |
327 | if (llvm::dyn_cast<IntegerType>(type)) |
328 | return getIntegerAttr(type, |
329 | APInt(llvm::cast<IntegerType>(type).getWidth(), 0)); |
330 | if (llvm::isa<RankedTensorType, VectorType>(Val: type)) { |
331 | auto vtType = llvm::cast<ShapedType>(type); |
332 | auto element = getZeroAttr(type: vtType.getElementType()); |
333 | if (!element) |
334 | return {}; |
335 | return DenseElementsAttr::get(vtType, element); |
336 | } |
337 | return {}; |
338 | } |
339 | |
340 | TypedAttr Builder::getOneAttr(Type type) { |
341 | if (llvm::isa<FloatType>(Val: type)) |
342 | return getFloatAttr(type, 1.0); |
343 | if (llvm::isa<IndexType>(type)) |
344 | return getIndexAttr(1); |
345 | if (llvm::dyn_cast<IntegerType>(type)) |
346 | return getIntegerAttr(type, |
347 | APInt(llvm::cast<IntegerType>(type).getWidth(), 1)); |
348 | if (llvm::isa<RankedTensorType, VectorType>(Val: type)) { |
349 | auto vtType = llvm::cast<ShapedType>(type); |
350 | auto element = getOneAttr(type: vtType.getElementType()); |
351 | if (!element) |
352 | return {}; |
353 | return DenseElementsAttr::get(vtType, element); |
354 | } |
355 | return {}; |
356 | } |
357 | |
358 | //===----------------------------------------------------------------------===// |
359 | // Affine Expressions, Affine Maps, and Integer Sets. |
360 | //===----------------------------------------------------------------------===// |
361 | |
362 | AffineExpr Builder::getAffineDimExpr(unsigned position) { |
363 | return mlir::getAffineDimExpr(position, context); |
364 | } |
365 | |
366 | AffineExpr Builder::getAffineSymbolExpr(unsigned position) { |
367 | return mlir::getAffineSymbolExpr(position, context); |
368 | } |
369 | |
370 | AffineExpr Builder::getAffineConstantExpr(int64_t constant) { |
371 | return mlir::getAffineConstantExpr(constant, context); |
372 | } |
373 | |
374 | AffineMap Builder::getEmptyAffineMap() { return AffineMap::get(context); } |
375 | |
376 | AffineMap Builder::getConstantAffineMap(int64_t val) { |
377 | return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0, |
378 | result: getAffineConstantExpr(constant: val)); |
379 | } |
380 | |
381 | AffineMap Builder::getDimIdentityMap() { |
382 | return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, result: getAffineDimExpr(position: 0)); |
383 | } |
384 | |
385 | AffineMap Builder::getMultiDimIdentityMap(unsigned rank) { |
386 | SmallVector<AffineExpr, 4> dimExprs; |
387 | dimExprs.reserve(N: rank); |
388 | for (unsigned i = 0; i < rank; ++i) |
389 | dimExprs.push_back(Elt: getAffineDimExpr(position: i)); |
390 | return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, results: dimExprs, |
391 | context); |
392 | } |
393 | |
394 | AffineMap Builder::getSymbolIdentityMap() { |
395 | return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1, |
396 | result: getAffineSymbolExpr(position: 0)); |
397 | } |
398 | |
399 | AffineMap Builder::getSingleDimShiftAffineMap(int64_t shift) { |
400 | // expr = d0 + shift. |
401 | auto expr = getAffineDimExpr(position: 0) + shift; |
402 | return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, result: expr); |
403 | } |
404 | |
405 | AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) { |
406 | SmallVector<AffineExpr, 4> shiftedResults; |
407 | shiftedResults.reserve(N: map.getNumResults()); |
408 | for (auto resultExpr : map.getResults()) |
409 | shiftedResults.push_back(Elt: resultExpr + shift); |
410 | return AffineMap::get(dimCount: map.getNumDims(), symbolCount: map.getNumSymbols(), results: shiftedResults, |
411 | context); |
412 | } |
413 | |
414 | //===----------------------------------------------------------------------===// |
415 | // OpBuilder |
416 | //===----------------------------------------------------------------------===// |
417 | |
418 | /// Insert the given operation at the current insertion point and return it. |
419 | Operation *OpBuilder::insert(Operation *op) { |
420 | if (block) { |
421 | block->getOperations().insert(where: insertPoint, New: op); |
422 | if (listener) |
423 | listener->notifyOperationInserted(op, /*previous=*/{}); |
424 | } |
425 | return op; |
426 | } |
427 | |
428 | Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt, |
429 | TypeRange argTypes, ArrayRef<Location> locs) { |
430 | assert(parent && "expected valid parent region" ); |
431 | assert(argTypes.size() == locs.size() && "argument location mismatch" ); |
432 | if (insertPt == Region::iterator()) |
433 | insertPt = parent->end(); |
434 | |
435 | Block *b = new Block(); |
436 | b->addArguments(types: argTypes, locs); |
437 | parent->getBlocks().insert(where: insertPt, New: b); |
438 | setInsertionPointToEnd(b); |
439 | |
440 | if (listener) |
441 | listener->notifyBlockInserted(block: b, /*previous=*/nullptr, /*previousIt=*/{}); |
442 | return b; |
443 | } |
444 | |
445 | /// Add new block with 'argTypes' arguments and set the insertion point to the |
446 | /// end of it. The block is placed before 'insertBefore'. |
447 | Block *OpBuilder::createBlock(Block *insertBefore, TypeRange argTypes, |
448 | ArrayRef<Location> locs) { |
449 | assert(insertBefore && "expected valid insertion block" ); |
450 | return createBlock(parent: insertBefore->getParent(), insertPt: Region::iterator(insertBefore), |
451 | argTypes, locs); |
452 | } |
453 | |
454 | /// Create an operation given the fields represented as an OperationState. |
455 | Operation *OpBuilder::create(const OperationState &state) { |
456 | return insert(op: Operation::create(state)); |
457 | } |
458 | |
459 | /// Creates an operation with the given fields. |
460 | Operation *OpBuilder::create(Location loc, StringAttr opName, |
461 | ValueRange operands, TypeRange types, |
462 | ArrayRef<NamedAttribute> attributes, |
463 | BlockRange successors, |
464 | MutableArrayRef<std::unique_ptr<Region>> regions) { |
465 | OperationState state(loc, opName, operands, types, attributes, successors, |
466 | regions); |
467 | return create(state); |
468 | } |
469 | |
470 | LogicalResult |
471 | OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value> &results, |
472 | SmallVectorImpl<Operation *> *materializedConstants) { |
473 | assert(results.empty() && "expected empty results" ); |
474 | ResultRange opResults = op->getResults(); |
475 | |
476 | results.reserve(N: opResults.size()); |
477 | auto cleanupFailure = [&] { |
478 | results.clear(); |
479 | return failure(); |
480 | }; |
481 | |
482 | // If this operation is already a constant, there is nothing to do. |
483 | if (matchPattern(op, pattern: m_Constant())) |
484 | return cleanupFailure(); |
485 | |
486 | // Try to fold the operation. |
487 | SmallVector<OpFoldResult, 4> foldResults; |
488 | if (failed(Result: op->fold(results&: foldResults))) |
489 | return cleanupFailure(); |
490 | |
491 | // An in-place fold does not require generation of any constants. |
492 | if (foldResults.empty()) |
493 | return success(); |
494 | |
495 | // A temporary builder used for creating constants during folding. |
496 | OpBuilder cstBuilder(context); |
497 | SmallVector<Operation *, 1> generatedConstants; |
498 | |
499 | // Populate the results with the folded results. |
500 | Dialect *dialect = op->getDialect(); |
501 | for (auto [foldResult, expectedType] : |
502 | llvm::zip_equal(t&: foldResults, u: opResults.getTypes())) { |
503 | |
504 | // Normal values get pushed back directly. |
505 | if (auto value = llvm::dyn_cast_if_present<Value>(Val&: foldResult)) { |
506 | results.push_back(Elt: value); |
507 | continue; |
508 | } |
509 | |
510 | // Otherwise, try to materialize a constant operation. |
511 | if (!dialect) |
512 | return cleanupFailure(); |
513 | |
514 | // Ask the dialect to materialize a constant operation for this value. |
515 | Attribute attr = cast<Attribute>(Val&: foldResult); |
516 | auto *constOp = dialect->materializeConstant(builder&: cstBuilder, value: attr, type: expectedType, |
517 | loc: op->getLoc()); |
518 | if (!constOp) { |
519 | // Erase any generated constants. |
520 | for (Operation *cst : generatedConstants) |
521 | cst->erase(); |
522 | return cleanupFailure(); |
523 | } |
524 | assert(matchPattern(constOp, m_Constant())); |
525 | |
526 | generatedConstants.push_back(Elt: constOp); |
527 | results.push_back(Elt: constOp->getResult(idx: 0)); |
528 | } |
529 | |
530 | // If we were successful, insert any generated constants. |
531 | for (Operation *cst : generatedConstants) |
532 | insert(op: cst); |
533 | |
534 | // Return materialized constant operations. |
535 | if (materializedConstants) |
536 | *materializedConstants = std::move(generatedConstants); |
537 | |
538 | return success(); |
539 | } |
540 | |
541 | /// Helper function that sends block insertion notifications for every block |
542 | /// that is directly nested in the given op. |
543 | static void notifyBlockInsertions(Operation *op, |
544 | OpBuilder::Listener *listener) { |
545 | for (Region &r : op->getRegions()) |
546 | for (Block &b : r.getBlocks()) |
547 | listener->notifyBlockInserted(block: &b, /*previous=*/nullptr, |
548 | /*previousIt=*/{}); |
549 | } |
550 | |
551 | Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) { |
552 | Operation *newOp = op.clone(mapper); |
553 | newOp = insert(op: newOp); |
554 | |
555 | // The `insert` call above handles the notification for inserting `newOp` |
556 | // itself. But if `newOp` has any regions, we need to notify the listener |
557 | // about any ops that got inserted inside those regions as part of cloning. |
558 | if (listener) { |
559 | // The `insert` call above notifies about op insertion, but not about block |
560 | // insertion. |
561 | notifyBlockInsertions(op: newOp, listener); |
562 | auto walkFn = [&](Operation *walkedOp) { |
563 | listener->notifyOperationInserted(op: walkedOp, /*previous=*/{}); |
564 | notifyBlockInsertions(op: walkedOp, listener); |
565 | }; |
566 | for (Region ®ion : newOp->getRegions()) |
567 | region.walk<WalkOrder::PreOrder>(callback&: walkFn); |
568 | } |
569 | |
570 | return newOp; |
571 | } |
572 | |
573 | Operation *OpBuilder::clone(Operation &op) { |
574 | IRMapping mapper; |
575 | return clone(op, mapper); |
576 | } |
577 | |
578 | void OpBuilder::cloneRegionBefore(Region ®ion, Region &parent, |
579 | Region::iterator before, IRMapping &mapping) { |
580 | region.cloneInto(dest: &parent, destPos: before, mapper&: mapping); |
581 | |
582 | // Fast path: If no listener is attached, there is no more work to do. |
583 | if (!listener) |
584 | return; |
585 | |
586 | // Notify about op/block insertion. |
587 | for (auto it = mapping.lookup(from: ®ion.front())->getIterator(); it != before; |
588 | ++it) { |
589 | listener->notifyBlockInserted(block: &*it, /*previous=*/nullptr, |
590 | /*previousIt=*/{}); |
591 | it->walk<WalkOrder::PreOrder>(callback: [&](Operation *walkedOp) { |
592 | listener->notifyOperationInserted(op: walkedOp, /*previous=*/{}); |
593 | notifyBlockInsertions(op: walkedOp, listener); |
594 | }); |
595 | } |
596 | } |
597 | |
598 | void OpBuilder::cloneRegionBefore(Region ®ion, Region &parent, |
599 | Region::iterator before) { |
600 | IRMapping mapping; |
601 | cloneRegionBefore(region, parent, before, mapping); |
602 | } |
603 | |
604 | void OpBuilder::cloneRegionBefore(Region ®ion, Block *before) { |
605 | cloneRegionBefore(region, parent&: *before->getParent(), before: before->getIterator()); |
606 | } |
607 | |