1//===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/Builders.h"
10#include "mlir/IR/AffineExpr.h"
11#include "mlir/IR/AffineMap.h"
12#include "mlir/IR/BuiltinTypes.h"
13#include "mlir/IR/Dialect.h"
14#include "mlir/IR/IRMapping.h"
15#include "mlir/IR/IntegerSet.h"
16#include "mlir/IR/Matchers.h"
17#include "mlir/IR/SymbolTable.h"
18#include "llvm/ADT/SmallVectorExtras.h"
19#include "llvm/Support/raw_ostream.h"
20
21using namespace mlir;
22
23//===----------------------------------------------------------------------===//
24// Locations.
25//===----------------------------------------------------------------------===//
26
27Location Builder::getUnknownLoc() { return UnknownLoc::get(context); }
28
29Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
30 return FusedLoc::get(locs, metadata, context);
31}
32
33//===----------------------------------------------------------------------===//
34// Types.
35//===----------------------------------------------------------------------===//
36
37FloatType Builder::getF8E8M0Type() { return Float8E8M0FNUType::get(context); }
38
39FloatType Builder::getBF16Type() { return BFloat16Type::get(context); }
40
41FloatType Builder::getF16Type() { return Float16Type::get(context); }
42
43FloatType Builder::getTF32Type() { return FloatTF32Type::get(context); }
44
45FloatType Builder::getF32Type() { return Float32Type::get(context); }
46
47FloatType Builder::getF64Type() { return Float64Type::get(context); }
48
49FloatType Builder::getF80Type() { return Float80Type::get(context); }
50
51FloatType Builder::getF128Type() { return Float128Type::get(context); }
52
53IndexType Builder::getIndexType() { return IndexType::get(context); }
54
55IntegerType Builder::getI1Type() { return IntegerType::get(context, 1); }
56
57IntegerType Builder::getI2Type() { return IntegerType::get(context, 2); }
58
59IntegerType Builder::getI4Type() { return IntegerType::get(context, 4); }
60
61IntegerType Builder::getI8Type() { return IntegerType::get(context, 8); }
62
63IntegerType Builder::getI16Type() { return IntegerType::get(context, 16); }
64
65IntegerType Builder::getI32Type() { return IntegerType::get(context, 32); }
66
67IntegerType Builder::getI64Type() { return IntegerType::get(context, 64); }
68
69IntegerType Builder::getIntegerType(unsigned width) {
70 return IntegerType::get(context, width);
71}
72
73IntegerType Builder::getIntegerType(unsigned width, bool isSigned) {
74 return IntegerType::get(
75 context, width, isSigned ? IntegerType::Signed : IntegerType::Unsigned);
76}
77
78FunctionType Builder::getFunctionType(TypeRange inputs, TypeRange results) {
79 return FunctionType::get(context, inputs, results);
80}
81
82TupleType Builder::getTupleType(TypeRange elementTypes) {
83 return TupleType::get(context, elementTypes);
84}
85
86NoneType Builder::getNoneType() { return NoneType::get(context); }
87
88//===----------------------------------------------------------------------===//
89// Attributes.
90//===----------------------------------------------------------------------===//
91
92NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) {
93 return NamedAttribute(name, val);
94}
95
96UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); }
97
98BoolAttr Builder::getBoolAttr(bool value) {
99 return BoolAttr::get(context, value);
100}
101
102DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) {
103 return DictionaryAttr::get(context, value);
104}
105
106IntegerAttr Builder::getIndexAttr(int64_t value) {
107 return IntegerAttr::get(getIndexType(), APInt(64, value));
108}
109
110IntegerAttr Builder::getI64IntegerAttr(int64_t value) {
111 return IntegerAttr::get(getIntegerType(64), APInt(64, value));
112}
113
114DenseIntElementsAttr Builder::getBoolVectorAttr(ArrayRef<bool> values) {
115 return DenseIntElementsAttr::get(
116 VectorType::get(static_cast<int64_t>(values.size()), getI1Type()),
117 values);
118}
119
120DenseIntElementsAttr Builder::getI32VectorAttr(ArrayRef<int32_t> values) {
121 return DenseIntElementsAttr::get(
122 VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(32)),
123 values);
124}
125
126DenseIntElementsAttr Builder::getI64VectorAttr(ArrayRef<int64_t> values) {
127 return DenseIntElementsAttr::get(
128 VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(64)),
129 values);
130}
131
132DenseIntElementsAttr Builder::getIndexVectorAttr(ArrayRef<int64_t> values) {
133 return DenseIntElementsAttr::get(
134 VectorType::get(static_cast<int64_t>(values.size()), getIndexType()),
135 values);
136}
137
138DenseFPElementsAttr Builder::getF32VectorAttr(ArrayRef<float> values) {
139 return DenseFPElementsAttr::get(
140 VectorType::get(static_cast<float>(values.size()), getF32Type()), values);
141}
142
143DenseFPElementsAttr Builder::getF64VectorAttr(ArrayRef<double> values) {
144 return DenseFPElementsAttr::get(
145 VectorType::get(static_cast<double>(values.size()), getF64Type()),
146 values);
147}
148
149DenseBoolArrayAttr Builder::getDenseBoolArrayAttr(ArrayRef<bool> values) {
150 return DenseBoolArrayAttr::get(context, values);
151}
152
153DenseI8ArrayAttr Builder::getDenseI8ArrayAttr(ArrayRef<int8_t> values) {
154 return DenseI8ArrayAttr::get(context, values);
155}
156
157DenseI16ArrayAttr Builder::getDenseI16ArrayAttr(ArrayRef<int16_t> values) {
158 return DenseI16ArrayAttr::get(context, values);
159}
160
161DenseI32ArrayAttr Builder::getDenseI32ArrayAttr(ArrayRef<int32_t> values) {
162 return DenseI32ArrayAttr::get(context, values);
163}
164
165DenseI64ArrayAttr Builder::getDenseI64ArrayAttr(ArrayRef<int64_t> values) {
166 return DenseI64ArrayAttr::get(context, values);
167}
168
169DenseF32ArrayAttr Builder::getDenseF32ArrayAttr(ArrayRef<float> values) {
170 return DenseF32ArrayAttr::get(context, values);
171}
172
173DenseF64ArrayAttr Builder::getDenseF64ArrayAttr(ArrayRef<double> values) {
174 return DenseF64ArrayAttr::get(context, values);
175}
176
177DenseIntElementsAttr Builder::getI32TensorAttr(ArrayRef<int32_t> values) {
178 return DenseIntElementsAttr::get(
179 RankedTensorType::get(static_cast<int64_t>(values.size()),
180 getIntegerType(32)),
181 values);
182}
183
184DenseIntElementsAttr Builder::getI64TensorAttr(ArrayRef<int64_t> values) {
185 return DenseIntElementsAttr::get(
186 RankedTensorType::get(static_cast<int64_t>(values.size()),
187 getIntegerType(64)),
188 values);
189}
190
191DenseIntElementsAttr Builder::getIndexTensorAttr(ArrayRef<int64_t> values) {
192 return DenseIntElementsAttr::get(
193 RankedTensorType::get(static_cast<int64_t>(values.size()),
194 getIndexType()),
195 values);
196}
197
198IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
199 // The APInt always uses isSigned=true here because we accept the value
200 // as int32_t.
201 return IntegerAttr::get(getIntegerType(32),
202 APInt(32, value, /*isSigned=*/true));
203}
204
205IntegerAttr Builder::getSI32IntegerAttr(int32_t value) {
206 return IntegerAttr::get(getIntegerType(32, /*isSigned=*/true),
207 APInt(32, value, /*isSigned=*/true));
208}
209
210IntegerAttr Builder::getUI32IntegerAttr(uint32_t value) {
211 return IntegerAttr::get(getIntegerType(32, /*isSigned=*/false),
212 APInt(32, (uint64_t)value, /*isSigned=*/false));
213}
214
215IntegerAttr Builder::getI16IntegerAttr(int16_t value) {
216 return IntegerAttr::get(getIntegerType(16), APInt(16, value));
217}
218
219IntegerAttr Builder::getI8IntegerAttr(int8_t value) {
220 // The APInt always uses isSigned=true here because we accept the value
221 // as int8_t.
222 return IntegerAttr::get(getIntegerType(8),
223 APInt(8, value, /*isSigned=*/true));
224}
225
226IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
227 if (type.isIndex())
228 return IntegerAttr::get(type, APInt(64, value));
229 // TODO: Avoid implicit trunc?
230 // See https://github.com/llvm/llvm-project/issues/112510.
231 return IntegerAttr::get(type, APInt(type.getIntOrFloatBitWidth(), value,
232 type.isSignedInteger(),
233 /*implicitTrunc=*/true));
234}
235
236IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {
237 return IntegerAttr::get(type, value);
238}
239
240FloatAttr Builder::getF64FloatAttr(double value) {
241 return FloatAttr::get(getF64Type(), APFloat(value));
242}
243
244FloatAttr Builder::getF32FloatAttr(float value) {
245 return FloatAttr::get(getF32Type(), APFloat(value));
246}
247
248FloatAttr Builder::getF16FloatAttr(float value) {
249 return FloatAttr::get(getF16Type(), value);
250}
251
252FloatAttr Builder::getFloatAttr(Type type, double value) {
253 return FloatAttr::get(type, value);
254}
255
256FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) {
257 return FloatAttr::get(type, value);
258}
259
260StringAttr Builder::getStringAttr(const Twine &bytes) {
261 return StringAttr::get(context, bytes);
262}
263
264ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
265 return ArrayAttr::get(context, value);
266}
267
268ArrayAttr Builder::getBoolArrayAttr(ArrayRef<bool> values) {
269 auto attrs = llvm::map_to_vector<8>(
270 C&: values, F: [this](bool v) -> Attribute { return getBoolAttr(value: v); });
271 return getArrayAttr(attrs);
272}
273
274ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) {
275 auto attrs = llvm::map_to_vector<8>(
276 C&: values, F: [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); });
277 return getArrayAttr(attrs);
278}
279ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) {
280 auto attrs = llvm::map_to_vector<8>(
281 C&: values, F: [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); });
282 return getArrayAttr(attrs);
283}
284
285ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) {
286 auto attrs = llvm::map_to_vector<8>(C&: values, F: [this](int64_t v) -> Attribute {
287 return getIntegerAttr(IndexType::get(getContext()), v);
288 });
289 return getArrayAttr(attrs);
290}
291
292ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) {
293 auto attrs = llvm::map_to_vector<8>(
294 C&: values, F: [this](float v) -> Attribute { return getF32FloatAttr(v); });
295 return getArrayAttr(attrs);
296}
297
298ArrayAttr Builder::getF64ArrayAttr(ArrayRef<double> values) {
299 auto attrs = llvm::map_to_vector<8>(
300 C&: values, F: [this](double v) -> Attribute { return getF64FloatAttr(v); });
301 return getArrayAttr(attrs);
302}
303
304ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
305 auto attrs = llvm::map_to_vector<8>(
306 C&: values, F: [this](StringRef v) -> Attribute { return getStringAttr(v); });
307 return getArrayAttr(attrs);
308}
309
310ArrayAttr Builder::getTypeArrayAttr(TypeRange values) {
311 auto attrs = llvm::map_to_vector<8>(
312 C&: values, F: [](Type v) -> Attribute { return TypeAttr::get(v); });
313 return getArrayAttr(attrs);
314}
315
316ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
317 auto attrs = llvm::map_to_vector<8>(
318 C&: values, F: [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); });
319 return getArrayAttr(attrs);
320}
321
322TypedAttr Builder::getZeroAttr(Type type) {
323 if (llvm::isa<FloatType>(Val: type))
324 return getFloatAttr(type, 0.0);
325 if (llvm::isa<IndexType>(type))
326 return getIndexAttr(0);
327 if (llvm::dyn_cast<IntegerType>(type))
328 return getIntegerAttr(type,
329 APInt(llvm::cast<IntegerType>(type).getWidth(), 0));
330 if (llvm::isa<RankedTensorType, VectorType>(Val: type)) {
331 auto vtType = llvm::cast<ShapedType>(type);
332 auto element = getZeroAttr(type: vtType.getElementType());
333 if (!element)
334 return {};
335 return DenseElementsAttr::get(vtType, element);
336 }
337 return {};
338}
339
340TypedAttr Builder::getOneAttr(Type type) {
341 if (llvm::isa<FloatType>(Val: type))
342 return getFloatAttr(type, 1.0);
343 if (llvm::isa<IndexType>(type))
344 return getIndexAttr(1);
345 if (llvm::dyn_cast<IntegerType>(type))
346 return getIntegerAttr(type,
347 APInt(llvm::cast<IntegerType>(type).getWidth(), 1));
348 if (llvm::isa<RankedTensorType, VectorType>(Val: type)) {
349 auto vtType = llvm::cast<ShapedType>(type);
350 auto element = getOneAttr(type: vtType.getElementType());
351 if (!element)
352 return {};
353 return DenseElementsAttr::get(vtType, element);
354 }
355 return {};
356}
357
358//===----------------------------------------------------------------------===//
359// Affine Expressions, Affine Maps, and Integer Sets.
360//===----------------------------------------------------------------------===//
361
362AffineExpr Builder::getAffineDimExpr(unsigned position) {
363 return mlir::getAffineDimExpr(position, context);
364}
365
366AffineExpr Builder::getAffineSymbolExpr(unsigned position) {
367 return mlir::getAffineSymbolExpr(position, context);
368}
369
370AffineExpr Builder::getAffineConstantExpr(int64_t constant) {
371 return mlir::getAffineConstantExpr(constant, context);
372}
373
374AffineMap Builder::getEmptyAffineMap() { return AffineMap::get(context); }
375
376AffineMap Builder::getConstantAffineMap(int64_t val) {
377 return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
378 result: getAffineConstantExpr(constant: val));
379}
380
381AffineMap Builder::getDimIdentityMap() {
382 return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, result: getAffineDimExpr(position: 0));
383}
384
385AffineMap Builder::getMultiDimIdentityMap(unsigned rank) {
386 SmallVector<AffineExpr, 4> dimExprs;
387 dimExprs.reserve(N: rank);
388 for (unsigned i = 0; i < rank; ++i)
389 dimExprs.push_back(Elt: getAffineDimExpr(position: i));
390 return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, results: dimExprs,
391 context);
392}
393
394AffineMap Builder::getSymbolIdentityMap() {
395 return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
396 result: getAffineSymbolExpr(position: 0));
397}
398
399AffineMap Builder::getSingleDimShiftAffineMap(int64_t shift) {
400 // expr = d0 + shift.
401 auto expr = getAffineDimExpr(position: 0) + shift;
402 return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, result: expr);
403}
404
405AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
406 SmallVector<AffineExpr, 4> shiftedResults;
407 shiftedResults.reserve(N: map.getNumResults());
408 for (auto resultExpr : map.getResults())
409 shiftedResults.push_back(Elt: resultExpr + shift);
410 return AffineMap::get(dimCount: map.getNumDims(), symbolCount: map.getNumSymbols(), results: shiftedResults,
411 context);
412}
413
414//===----------------------------------------------------------------------===//
415// OpBuilder
416//===----------------------------------------------------------------------===//
417
418/// Insert the given operation at the current insertion point and return it.
419Operation *OpBuilder::insert(Operation *op) {
420 if (block) {
421 block->getOperations().insert(where: insertPoint, New: op);
422 if (listener)
423 listener->notifyOperationInserted(op, /*previous=*/{});
424 }
425 return op;
426}
427
428Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt,
429 TypeRange argTypes, ArrayRef<Location> locs) {
430 assert(parent && "expected valid parent region");
431 assert(argTypes.size() == locs.size() && "argument location mismatch");
432 if (insertPt == Region::iterator())
433 insertPt = parent->end();
434
435 Block *b = new Block();
436 b->addArguments(types: argTypes, locs);
437 parent->getBlocks().insert(where: insertPt, New: b);
438 setInsertionPointToEnd(b);
439
440 if (listener)
441 listener->notifyBlockInserted(block: b, /*previous=*/nullptr, /*previousIt=*/{});
442 return b;
443}
444
445/// Add new block with 'argTypes' arguments and set the insertion point to the
446/// end of it. The block is placed before 'insertBefore'.
447Block *OpBuilder::createBlock(Block *insertBefore, TypeRange argTypes,
448 ArrayRef<Location> locs) {
449 assert(insertBefore && "expected valid insertion block");
450 return createBlock(parent: insertBefore->getParent(), insertPt: Region::iterator(insertBefore),
451 argTypes, locs);
452}
453
454/// Create an operation given the fields represented as an OperationState.
455Operation *OpBuilder::create(const OperationState &state) {
456 return insert(op: Operation::create(state));
457}
458
459/// Creates an operation with the given fields.
460Operation *OpBuilder::create(Location loc, StringAttr opName,
461 ValueRange operands, TypeRange types,
462 ArrayRef<NamedAttribute> attributes,
463 BlockRange successors,
464 MutableArrayRef<std::unique_ptr<Region>> regions) {
465 OperationState state(loc, opName, operands, types, attributes, successors,
466 regions);
467 return create(state);
468}
469
470LogicalResult
471OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value> &results,
472 SmallVectorImpl<Operation *> *materializedConstants) {
473 assert(results.empty() && "expected empty results");
474 ResultRange opResults = op->getResults();
475
476 results.reserve(N: opResults.size());
477 auto cleanupFailure = [&] {
478 results.clear();
479 return failure();
480 };
481
482 // If this operation is already a constant, there is nothing to do.
483 if (matchPattern(op, pattern: m_Constant()))
484 return cleanupFailure();
485
486 // Try to fold the operation.
487 SmallVector<OpFoldResult, 4> foldResults;
488 if (failed(Result: op->fold(results&: foldResults)))
489 return cleanupFailure();
490
491 // An in-place fold does not require generation of any constants.
492 if (foldResults.empty())
493 return success();
494
495 // A temporary builder used for creating constants during folding.
496 OpBuilder cstBuilder(context);
497 SmallVector<Operation *, 1> generatedConstants;
498
499 // Populate the results with the folded results.
500 Dialect *dialect = op->getDialect();
501 for (auto [foldResult, expectedType] :
502 llvm::zip_equal(t&: foldResults, u: opResults.getTypes())) {
503
504 // Normal values get pushed back directly.
505 if (auto value = llvm::dyn_cast_if_present<Value>(Val&: foldResult)) {
506 results.push_back(Elt: value);
507 continue;
508 }
509
510 // Otherwise, try to materialize a constant operation.
511 if (!dialect)
512 return cleanupFailure();
513
514 // Ask the dialect to materialize a constant operation for this value.
515 Attribute attr = cast<Attribute>(Val&: foldResult);
516 auto *constOp = dialect->materializeConstant(builder&: cstBuilder, value: attr, type: expectedType,
517 loc: op->getLoc());
518 if (!constOp) {
519 // Erase any generated constants.
520 for (Operation *cst : generatedConstants)
521 cst->erase();
522 return cleanupFailure();
523 }
524 assert(matchPattern(constOp, m_Constant()));
525
526 generatedConstants.push_back(Elt: constOp);
527 results.push_back(Elt: constOp->getResult(idx: 0));
528 }
529
530 // If we were successful, insert any generated constants.
531 for (Operation *cst : generatedConstants)
532 insert(op: cst);
533
534 // Return materialized constant operations.
535 if (materializedConstants)
536 *materializedConstants = std::move(generatedConstants);
537
538 return success();
539}
540
541/// Helper function that sends block insertion notifications for every block
542/// that is directly nested in the given op.
543static void notifyBlockInsertions(Operation *op,
544 OpBuilder::Listener *listener) {
545 for (Region &r : op->getRegions())
546 for (Block &b : r.getBlocks())
547 listener->notifyBlockInserted(block: &b, /*previous=*/nullptr,
548 /*previousIt=*/{});
549}
550
551Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
552 Operation *newOp = op.clone(mapper);
553 newOp = insert(op: newOp);
554
555 // The `insert` call above handles the notification for inserting `newOp`
556 // itself. But if `newOp` has any regions, we need to notify the listener
557 // about any ops that got inserted inside those regions as part of cloning.
558 if (listener) {
559 // The `insert` call above notifies about op insertion, but not about block
560 // insertion.
561 notifyBlockInsertions(op: newOp, listener);
562 auto walkFn = [&](Operation *walkedOp) {
563 listener->notifyOperationInserted(op: walkedOp, /*previous=*/{});
564 notifyBlockInsertions(op: walkedOp, listener);
565 };
566 for (Region &region : newOp->getRegions())
567 region.walk<WalkOrder::PreOrder>(callback&: walkFn);
568 }
569
570 return newOp;
571}
572
573Operation *OpBuilder::clone(Operation &op) {
574 IRMapping mapper;
575 return clone(op, mapper);
576}
577
578void OpBuilder::cloneRegionBefore(Region &region, Region &parent,
579 Region::iterator before, IRMapping &mapping) {
580 region.cloneInto(dest: &parent, destPos: before, mapper&: mapping);
581
582 // Fast path: If no listener is attached, there is no more work to do.
583 if (!listener)
584 return;
585
586 // Notify about op/block insertion.
587 for (auto it = mapping.lookup(from: &region.front())->getIterator(); it != before;
588 ++it) {
589 listener->notifyBlockInserted(block: &*it, /*previous=*/nullptr,
590 /*previousIt=*/{});
591 it->walk<WalkOrder::PreOrder>(callback: [&](Operation *walkedOp) {
592 listener->notifyOperationInserted(op: walkedOp, /*previous=*/{});
593 notifyBlockInsertions(op: walkedOp, listener);
594 });
595 }
596}
597
598void OpBuilder::cloneRegionBefore(Region &region, Region &parent,
599 Region::iterator before) {
600 IRMapping mapping;
601 cloneRegionBefore(region, parent, before, mapping);
602}
603
604void OpBuilder::cloneRegionBefore(Region &region, Block *before) {
605 cloneRegionBefore(region, parent&: *before->getParent(), before: before->getIterator());
606}
607

// Source: mlir/lib/IR/Builders.cpp