//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/TensorEncoding.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
/// Tablegen Type Definitions
//===----------------------------------------------------------------------===//

#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"

namespace mlir {
#include "mlir/IR/BuiltinTypeConstraints.cpp.inc"
} // namespace mlir

//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//

void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
/// ComplexType
//===----------------------------------------------------------------------===//

/// Verify the construction of a complex type.
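/// For example, `complex<f32>` and `complex<i32>` are valid, while
/// `complex<index>` and `complex<vector<4xf32>>` are rejected because the
/// element must be a builtin integer or floating-point type.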
LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  Type elementType) {
  if (!elementType.isIntOrFloat())
    return emitError() << "invalid element type for complex";
  return success();
}

//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//

/// Verify the construction of an integer type.
LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  unsigned width,
                                  SignednessSemantics signedness) {
  if (width > IntegerType::kMaxWidth) {
    return emitError() << "integer bitwidth is limited to "
                       << IntegerType::kMaxWidth << " bits";
  }
  return success();
}

unsigned IntegerType::getWidth() const { return getImpl()->width; }

IntegerType::SignednessSemantics IntegerType::getSignedness() const {
  return getImpl()->signedness;
}

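/// Scaling multiplies the bitwidth and preserves the signedness; a zero scale
/// yields the null type. A minimal usage sketch (`ctx` is an assumed
/// MLIRContext):
///
///   IntegerType i8 = IntegerType::get(ctx, 8);
///   IntegerType i32 = i8.scaleElementBitwidth(4);   // i32
///   IntegerType none = i8.scaleElementBitwidth(0);  // null IntegerType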
IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return IntegerType();
  return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
}

//===----------------------------------------------------------------------===//
// Float Types
//===----------------------------------------------------------------------===//

// Mapping from MLIR FloatType to APFloat semantics.
#define FLOAT_TYPE_SEMANTICS(TYPE, SEM)                                        \
  const llvm::fltSemantics &TYPE::getFloatSemantics() const {                  \
    return APFloat::SEM();                                                     \
  }
FLOAT_TYPE_SEMANTICS(Float4E2M1FNType, Float4E2M1FN)
FLOAT_TYPE_SEMANTICS(Float6E2M3FNType, Float6E2M3FN)
FLOAT_TYPE_SEMANTICS(Float6E3M2FNType, Float6E3M2FN)
FLOAT_TYPE_SEMANTICS(Float8E5M2Type, Float8E5M2)
FLOAT_TYPE_SEMANTICS(Float8E4M3Type, Float8E4M3)
FLOAT_TYPE_SEMANTICS(Float8E4M3FNType, Float8E4M3FN)
FLOAT_TYPE_SEMANTICS(Float8E5M2FNUZType, Float8E5M2FNUZ)
FLOAT_TYPE_SEMANTICS(Float8E4M3FNUZType, Float8E4M3FNUZ)
FLOAT_TYPE_SEMANTICS(Float8E4M3B11FNUZType, Float8E4M3B11FNUZ)
FLOAT_TYPE_SEMANTICS(Float8E3M4Type, Float8E3M4)
FLOAT_TYPE_SEMANTICS(Float8E8M0FNUType, Float8E8M0FNU)
FLOAT_TYPE_SEMANTICS(BFloat16Type, BFloat)
FLOAT_TYPE_SEMANTICS(Float16Type, IEEEhalf)
FLOAT_TYPE_SEMANTICS(FloatTF32Type, FloatTF32)
FLOAT_TYPE_SEMANTICS(Float32Type, IEEEsingle)
FLOAT_TYPE_SEMANTICS(Float64Type, IEEEdouble)
FLOAT_TYPE_SEMANTICS(Float80Type, x87DoubleExtended)
FLOAT_TYPE_SEMANTICS(Float128Type, IEEEquad)
#undef FLOAT_TYPE_SEMANTICS

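/// As defined below, f16, bf16, and f32 scale to the wider standard float type
/// for supported factors; unsupported scales yield the null type. A minimal
/// usage sketch (`ctx` is an assumed MLIRContext):
///
///   FloatType f16 = Float16Type::get(ctx);
///   FloatType f32 = f16.scaleElementBitwidth(2);   // f32
///   FloatType none = f16.scaleElementBitwidth(3);  // null FloatType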
FloatType Float16Type::scaleElementBitwidth(unsigned scale) const {
  if (scale == 2)
    return Float32Type::get(getContext());
  if (scale == 4)
    return Float64Type::get(getContext());
  return FloatType();
}

FloatType BFloat16Type::scaleElementBitwidth(unsigned scale) const {
  if (scale == 2)
    return Float32Type::get(getContext());
  if (scale == 4)
    return Float64Type::get(getContext());
  return FloatType();
}

FloatType Float32Type::scaleElementBitwidth(unsigned scale) const {
  if (scale == 2)
    return Float64Type::get(getContext());
  return FloatType();
}

//===----------------------------------------------------------------------===//
// FunctionType
//===----------------------------------------------------------------------===//

unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }

ArrayRef<Type> FunctionType::getInputs() const {
  return getImpl()->getInputs();
}

unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }

ArrayRef<Type> FunctionType::getResults() const {
  return getImpl()->getResults();
}

FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
  return get(getContext(), inputs, results);
}

/// Returns a new function type with the specified arguments and results
/// inserted.
FunctionType FunctionType::getWithArgsAndResults(
    ArrayRef<unsigned> argIndices, TypeRange argTypes,
    ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
  SmallVector<Type> argStorage, resultStorage;
  TypeRange newArgTypes =
      insertTypesInto(getInputs(), argIndices, argTypes, argStorage);
  TypeRange newResultTypes =
      insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
  return clone(newArgTypes, newResultTypes);
}

/// Returns a new function type without the specified arguments and results.
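/// For example (illustrative), erasing argument 0 and result 1 from
/// `(i1, i32) -> (f32, index)` yields `(i32) -> (f32)`; the bit vectors index
/// into the original argument and result lists.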
FunctionType
FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
                                       const BitVector &resultIndices) {
  SmallVector<Type> argStorage, resultStorage;
  TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage);
  TypeRange newResultTypes =
      filterTypesOut(getResults(), resultIndices, resultStorage);
  return clone(newArgTypes, newResultTypes);
}

//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//

/// Verify the construction of an opaque type.
LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 StringAttr dialect, StringRef typeData) {
  if (!Dialect::isValidNamespace(dialect.strref()))
    return emitError() << "invalid dialect namespace '" << dialect << "'";

  // Check that the dialect is actually registered.
  MLIRContext *context = dialect.getContext();
  if (!context->allowsUnregisteredDialects() &&
      !context->getLoadedDialect(dialect.strref())) {
    return emitError()
           << "`!" << dialect << "<\"" << typeData << "\">"
           << "` type created with unregistered dialect. If this is "
              "intended, please call allowUnregisteredDialects() on the "
              "MLIRContext, or use -allow-unregistered-dialect with "
              "the MLIR opt tool used";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// VectorType
//===----------------------------------------------------------------------===//

bool VectorType::isValidElementType(Type t) {
  return isValidVectorTypeElementType(t);
}

LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
                                 ArrayRef<bool> scalableDims) {
  if (!isValidElementType(elementType))
    return emitError()
           << "vector elements must be int/index/float type but got "
           << elementType;

  if (any_of(shape, [](int64_t i) { return i <= 0; }))
    return emitError()
           << "vector types must have positive constant sizes but got "
           << shape;

  if (scalableDims.size() != shape.size())
    return emitError() << "number of dims must match, got "
                       << scalableDims.size() << " and " << shape.size();

  return success();
}

VectorType VectorType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return VectorType();
  if (auto et = llvm::dyn_cast<IntegerType>(getElementType()))
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt, getScalableDims());
  if (auto et = llvm::dyn_cast<FloatType>(getElementType()))
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt, getScalableDims());
  return VectorType();
}

VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
                                 Type elementType) const {
  return VectorType::get(shape.value_or(getShape()), elementType,
                         getScalableDims());
}

//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//

Type TensorType::getElementType() const {
  return llvm::TypeSwitch<TensorType, Type>(*this)
      .Case<RankedTensorType, UnrankedTensorType>(
          [](auto type) { return type.getElementType(); });
}

bool TensorType::hasRank() const {
  return !llvm::isa<UnrankedTensorType>(*this);
}

ArrayRef<int64_t> TensorType::getShape() const {
  return llvm::cast<RankedTensorType>(*this).getShape();
}

TensorType TensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
                                 Type elementType) const {
  if (llvm::dyn_cast<UnrankedTensorType>(*this)) {
    if (shape)
      return RankedTensorType::get(*shape, elementType);
    return UnrankedTensorType::get(elementType);
  }

  auto rankedTy = llvm::cast<RankedTensorType>(*this);
  if (!shape)
    return RankedTensorType::get(rankedTy.getShape(), elementType,
                                 rankedTy.getEncoding());
  return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType,
                               rankedTy.getEncoding());
}

RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape,
                                   Type elementType) const {
  return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
}

RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape) const {
  return ::llvm::cast<RankedTensorType>(cloneWith(shape, getElementType()));
}

// Check if "elementType" can be an element type of a tensor.
static LogicalResult
checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
                       Type elementType) {
  if (!TensorType::isValidElementType(elementType))
    return emitError() << "invalid tensor element type: " << elementType;
  return success();
}

/// Return true if the specified element type is ok in a tensor.
bool TensorType::isValidElementType(Type type) {
  // Note: Non standard/builtin types are allowed to exist within tensor
  // types. Dialects are expected to verify that tensor types have a valid
  // element type within that dialect.
  return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
                   IndexType>(type) ||
         !llvm::isa<BuiltinDialect>(type.getDialect());
}

//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//

LogicalResult
RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                         ArrayRef<int64_t> shape, Type elementType,
                         Attribute encoding) {
  for (int64_t s : shape)
    if (s < 0 && !ShapedType::isDynamic(s))
      return emitError() << "invalid tensor dimension size";
  if (auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
    if (failed(v.verifyEncoding(shape, elementType, emitError)))
      return failure();
  return checkTensorElementType(emitError, elementType);
}

//===----------------------------------------------------------------------===//
// UnrankedTensorType
//===----------------------------------------------------------------------===//

LogicalResult
UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType) {
  return checkTensorElementType(emitError, elementType);
}

//===----------------------------------------------------------------------===//
// BaseMemRefType
//===----------------------------------------------------------------------===//

Type BaseMemRefType::getElementType() const {
  return llvm::TypeSwitch<BaseMemRefType, Type>(*this)
      .Case<MemRefType, UnrankedMemRefType>(
          [](auto type) { return type.getElementType(); });
}

bool BaseMemRefType::hasRank() const {
  return !llvm::isa<UnrankedMemRefType>(*this);
}

ArrayRef<int64_t> BaseMemRefType::getShape() const {
  return llvm::cast<MemRefType>(*this).getShape();
}

BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
                                         Type elementType) const {
  if (llvm::dyn_cast<UnrankedMemRefType>(*this)) {
    if (!shape)
      return UnrankedMemRefType::get(elementType, getMemorySpace());
    MemRefType::Builder builder(*shape, elementType);
    builder.setMemorySpace(getMemorySpace());
    return builder;
  }

  MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
  if (shape)
    builder.setShape(*shape);
  builder.setElementType(elementType);
  return builder;
}

MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape,
                                 Type elementType) const {
  return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
}

MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const {
  return ::llvm::cast<MemRefType>(cloneWith(shape, getElementType()));
}

Attribute BaseMemRefType::getMemorySpace() const {
  if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
    return rankedMemRefTy.getMemorySpace();
  return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
}

unsigned BaseMemRefType::getMemorySpaceAsInt() const {
  if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
    return rankedMemRefTy.getMemorySpaceAsInt();
  return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
}

//===----------------------------------------------------------------------===//
// MemRefType
//===----------------------------------------------------------------------===//

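/// For example (illustrative), an original shape of [1, 4, 1, 8] and a reduced
/// shape of [4, 8] produce the unused-dimension mask {0, 2}; a reduced shape
/// of [4, 4] cannot be matched and std::nullopt is returned.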
std::optional<llvm::SmallDenseSet<unsigned>>
mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
                               ArrayRef<int64_t> reducedShape,
                               bool matchDynamic) {
  size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
  llvm::SmallDenseSet<unsigned> unusedDims;
  unsigned reducedIdx = 0;
  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
    // Greedily match `originalIdx` against the next reduced dimension.
    int64_t origSize = originalShape[originalIdx];
    // If `matchDynamic`, count dynamic dims as a match, unless `origSize` is 1.
    if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
        (ShapedType::isDynamic(reducedShape[reducedIdx]) ||
         ShapedType::isDynamic(origSize))) {
      reducedIdx++;
      continue;
    }
    if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
      reducedIdx++;
      continue;
    }

    unusedDims.insert(originalIdx);
    // If no match on `originalIdx`, the `originalShape` at this dimension
    // must be 1, otherwise we bail.
    if (origSize != 1)
      return std::nullopt;
  }
  // The whole reducedShape must be scanned, otherwise we bail.
  if (reducedIdx != reducedRank)
    return std::nullopt;
  return unusedDims;
}

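/// For example (illustrative), reducing `tensor<1x4x1x8xf32>` to
/// `tensor<4x8xf32>` verifies successfully, reducing it to `tensor<4x8xi32>`
/// reports an element type mismatch, and reducing it to a rank-5 type reports
/// that the candidate rank is too large.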
SliceVerificationResult
mlir::isRankReducedType(ShapedType originalType,
                        ShapedType candidateReducedType) {
  if (originalType == candidateReducedType)
    return SliceVerificationResult::Success;

  ShapedType originalShapedType = llvm::cast<ShapedType>(originalType);
  ShapedType candidateReducedShapedType =
      llvm::cast<ShapedType>(candidateReducedType);

  // Rank and size logic is valid for all ShapedTypes.
  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
  ArrayRef<int64_t> candidateReducedShape =
      candidateReducedShapedType.getShape();
  unsigned originalRank = originalShape.size(),
           candidateReducedRank = candidateReducedShape.size();
  if (candidateReducedRank > originalRank)
    return SliceVerificationResult::RankTooLarge;

  auto optionalUnusedDimsMask =
      computeRankReductionMask(originalShape, candidateReducedShape);

  // Sizes could not be matched if no rank-reduction mask was computed.
  if (!optionalUnusedDimsMask)
    return SliceVerificationResult::SizeMismatch;

  if (originalShapedType.getElementType() !=
      candidateReducedShapedType.getElementType())
    return SliceVerificationResult::ElemTypeMismatch;

  return SliceVerificationResult::Success;
}

bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
  // Empty attribute is allowed as default memory space.
  if (!memorySpace)
    return true;

  // Supported built-in attributes.
  if (llvm::isa<IntegerAttr, StringAttr, DictionaryAttr>(memorySpace))
    return true;

  // Allow custom dialect attributes.
  if (!isa<BuiltinDialect>(memorySpace.getDialect()))
    return true;

  return false;
}

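/// The default memory space (integer value 0) is canonically represented as a
/// null attribute. For example (illustrative, `ctx` assumed),
/// `wrapIntegerMemorySpace(0, ctx)` returns a null Attribute, while
/// `wrapIntegerMemorySpace(3, ctx)` returns an i64 IntegerAttr with value 3.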
Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
                                               MLIRContext *ctx) {
  if (memorySpace == 0)
    return nullptr;

  return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
}

Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
  IntegerAttr intMemorySpace = llvm::dyn_cast_or_null<IntegerAttr>(memorySpace);
  if (intMemorySpace && intMemorySpace.getValue() == 0)
    return nullptr;

  return memorySpace;
}

unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
  if (!memorySpace)
    return 0;

  assert(llvm::isa<IntegerAttr>(memorySpace) &&
         "Using `getMemorySpaceInteger` with non-Integer attribute");

  return static_cast<unsigned>(llvm::cast<IntegerAttr>(memorySpace).getInt());
}

unsigned MemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}

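/// The `get`/`getChecked` builders below all normalize their inputs the same
/// way: a missing layout becomes the multi-dimensional identity AffineMapAttr
/// and the default memory space collapses to a null attribute. For example
/// (illustrative, assuming an f32 element type is at hand),
/// `MemRefType::get({4, 8}, f32)` produces the type printed as
/// `memref<4x8xf32>`, with an identity layout and no memory space annotation.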
MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           MemRefLayoutAttrInterface layout,
                           Attribute memorySpace) {
  // Use default layout for empty attribute.
  if (!layout)
    layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
        shape.size(), elementType.getContext()));

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}

MemRefType MemRefType::getChecked(
    function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
    Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {

  // Use default layout for empty attribute.
  if (!layout)
    layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
        shape.size(), elementType.getContext()));

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}

MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           AffineMap map, Attribute memorySpace) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  auto layout = AffineMapAttr::get(map);

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}

MemRefType
MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                       ArrayRef<int64_t> shape, Type elementType, AffineMap map,
                       Attribute memorySpace) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  auto layout = AffineMapAttr::get(map);

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}

MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           AffineMap map, unsigned memorySpaceInd) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  auto layout = AffineMapAttr::get(map);

  // Convert deprecated integer-like memory space to Attribute.
  Attribute memorySpace =
      wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}

MemRefType
MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                       ArrayRef<int64_t> shape, Type elementType, AffineMap map,
                       unsigned memorySpaceInd) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  auto layout = AffineMapAttr::get(map);

  // Convert deprecated integer-like memory space to Attribute.
  Attribute memorySpace =
      wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}

LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
                                 MemRefLayoutAttrInterface layout,
                                 Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  // Negative sizes are not allowed except for `kDynamic`.
  for (int64_t s : shape)
    if (s < 0 && !ShapedType::isDynamic(s))
      return emitError() << "invalid memref size";

  assert(layout && "missing layout specification");
  if (failed(layout.verifyLayout(shape, emitError)))
    return failure();

  if (!isSupportedMemorySpace(memorySpace))
    return emitError() << "unsupported memory space Attribute";

  return success();
}

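/// For example (illustrative), `memref<2x3x4xf32>` with the identity layout
/// has strides [12, 4, 1], so any number of trailing dimensions is contiguous;
/// with an explicit strided layout of [24, 4, 1], the trailing two dimensions
/// are contiguous but all three are not, since the outermost stride (24) does
/// not equal 3 * 4.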
bool MemRefType::areTrailingDimsContiguous(int64_t n) {
  if (!isLastDimUnitStride())
    return false;

  auto memrefShape = getShape().take_back(n);
  if (ShapedType::isDynamicShape(memrefShape))
    return false;

  if (getLayout().isIdentity())
    return true;

  int64_t offset;
  SmallVector<int64_t> stridesFull;
  if (!succeeded(getStridesAndOffset(stridesFull, offset)))
    return false;
  auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);

  if (strides.empty())
    return true;

  // Check whether strides match "flattened" dims.
  SmallVector<int64_t> flattenedDims;
  auto dimProduct = 1;
  for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
    dimProduct *= dim;
    flattenedDims.push_back(dimProduct);
  }

  strides = strides.drop_back(1);
  return llvm::equal(strides, llvm::reverse(flattenedDims));
}

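/// For example (illustrative), `memref<4x8xf32, affine_map<(d0, d1) ->
/// (d0 * 8 + d1)>>` canonicalizes to plain `memref<4x8xf32>` because its
/// layout equals the canonical strided layout for that shape, whereas a layout
/// such as `(d0, d1) -> (d0 * 16 + d1)` is only simplified, not dropped.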
MemRefType MemRefType::canonicalizeStridedLayout() {
  AffineMap m = getLayout().getAffineMap();

  // Already in canonical form.
  if (m.isIdentity())
    return *this;

  // Can't reduce to canonical identity form, return in canonical form.
  if (m.getNumResults() > 1)
    return *this;

  // Corner-case for 0-D affine maps.
  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
    if (auto cst = llvm::dyn_cast<AffineConstantExpr>(m.getResult(0)))
      if (cst.getValue() == 0)
        return MemRefType::Builder(*this).setLayout({});
    return *this;
  }

  // 0-D corner case for an empty shape that still has an affine map. Example:
  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1-element memref
  // whose offset needs to remain, so just return the type unchanged.
  if (getShape().empty())
    return *this;

  // If the canonical strided layout for this memref's sizes is equal to its
  // simplified layout, we can just return an empty layout. Otherwise, just
  // simplify the existing layout.
  AffineExpr expr = makeCanonicalStridedLayoutExpr(getShape(), getContext());
  auto simplifiedLayoutExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (expr != simplifiedLayoutExpr)
    return MemRefType::Builder(*this).setLayout(
        AffineMapAttr::get(AffineMap::get(m.getNumDims(), m.getNumSymbols(),
                                          simplifiedLayoutExpr)));
  return MemRefType::Builder(*this).setLayout({});
}

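/// For example (illustrative), `memref<3x4x5xf32>` has strides [20, 5, 1] and
/// offset 0, while `memref<3x4xf32, strided<[?, 1], offset: ?>>` reports
/// ShapedType::kDynamic for the unknown stride and offset.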
LogicalResult MemRefType::getStridesAndOffset(SmallVectorImpl<int64_t> &strides,
                                              int64_t &offset) {
  return getLayout().getStridesAndOffset(getShape(), strides, offset);
}

std::pair<SmallVector<int64_t>, int64_t> MemRefType::getStridesAndOffset() {
  SmallVector<int64_t> strides;
  int64_t offset;
  LogicalResult status = getStridesAndOffset(strides, offset);
  (void)status;
  assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset");
  return {strides, offset};
}

bool MemRefType::isStrided() {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(strides, offset);
  return succeeded(res);
}

bool MemRefType::isLastDimUnitStride() {
  int64_t offset;
  SmallVector<int64_t> strides;
  auto successStrides = getStridesAndOffset(strides, offset);
  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
}

//===----------------------------------------------------------------------===//
// UnrankedMemRefType
//===----------------------------------------------------------------------===//

unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}

LogicalResult
UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType, Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  if (!isSupportedMemorySpace(memorySpace))
    return emitError() << "unsupported memory space Attribute";

  return success();
}

//===----------------------------------------------------------------------===//
/// TupleType
//===----------------------------------------------------------------------===//

/// Return the element types for this tuple.
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }

/// Accumulate the types contained in this tuple and tuples nested within it.
/// Note that this only flattens nested tuples, not any other container type,
/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
/// (i32, tensor<i32>, f32, i64).
void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
  for (Type type : getTypes()) {
    if (auto nestedTuple = llvm::dyn_cast<TupleType>(type))
      nestedTuple.getFlattenedTypes(types);
    else
      types.push_back(type);
  }
}

/// Return the number of element types.
size_t TupleType::size() const { return getImpl()->size(); }

//===----------------------------------------------------------------------===//
// Type Utilities
//===----------------------------------------------------------------------===//

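/// For example (illustrative), static sizes [3, 4, 5] with dimension
/// expressions (d0, d1, d2) yield the canonical strided layout expression
/// d0 * 20 + d1 * 5 + d2; a dynamic size turns every stride to its left into a
/// fresh symbol, so sizes [3, ?, 5] yield d0 * s0 + d1 * 5 + d2.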
AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                ArrayRef<AffineExpr> exprs,
                                                MLIRContext *context) {
  // Size 0 corner case is useful for canonicalizations.
  if (sizes.empty())
    return getAffineConstantExpr(0, context);

  assert(!exprs.empty() && "expected exprs");
  auto maps = AffineMap::inferFromExprList(exprs, context);
  assert(!maps.empty() && "Expected one non-empty map");
  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();

  AffineExpr expr;
  bool dynamicPoisonBit = false;
  int64_t runningSize = 1;
  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
    int64_t size = std::get<1>(en);
    AffineExpr dimExpr = std::get<0>(en);
    AffineExpr stride = dynamicPoisonBit
                            ? getAffineSymbolExpr(nSymbols++, context)
                            : getAffineConstantExpr(runningSize, context);
    expr = expr ? expr + dimExpr * stride : dimExpr * stride;
    if (size > 0) {
      runningSize *= size;
      assert(runningSize > 0 && "integer overflow in size computation");
    } else {
      dynamicPoisonBit = true;
    }
  }
  return simplifyAffineExpr(expr, numDims, nSymbols);
}

AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                MLIRContext *context) {
  SmallVector<AffineExpr, 4> exprs;
  exprs.reserve(sizes.size());
  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
    exprs.push_back(getAffineDimExpr(dim, context));
  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
}