1//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/TensorEncoding.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/TypeSwitch.h"
#include <limits>
22
23using namespace mlir;
24using namespace mlir::detail;
25
26//===----------------------------------------------------------------------===//
// Tablegen Type Definitions
28//===----------------------------------------------------------------------===//
29
30#define GET_TYPEDEF_CLASSES
31#include "mlir/IR/BuiltinTypes.cpp.inc"
32
33namespace mlir {
34#include "mlir/IR/BuiltinTypeConstraints.cpp.inc"
35} // namespace mlir
36
37//===----------------------------------------------------------------------===//
38// BuiltinDialect
39//===----------------------------------------------------------------------===//
40
/// Register the builtin types with the builtin dialect. The template argument
/// list of `addTypes` is supplied by including the TableGen-generated
/// `BuiltinTypes.cpp.inc` with `GET_TYPEDEF_LIST` defined.
void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}
47
48//===----------------------------------------------------------------------===//
// ComplexType
50//===----------------------------------------------------------------------===//
51
52/// Verify the construction of an integer type.
53LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
54 Type elementType) {
55 if (!elementType.isIntOrFloat())
56 return emitError() << "invalid element type for complex";
57 return success();
58}
59
60//===----------------------------------------------------------------------===//
61// Integer Type
62//===----------------------------------------------------------------------===//
63
64/// Verify the construction of an integer type.
65LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
66 unsigned width,
67 SignednessSemantics signedness) {
68 if (width > IntegerType::kMaxWidth) {
69 return emitError() << "integer bitwidth is limited to "
70 << IntegerType::kMaxWidth << " bits";
71 }
72 return success();
73}
74
75unsigned IntegerType::getWidth() const { return getImpl()->width; }
76
77IntegerType::SignednessSemantics IntegerType::getSignedness() const {
78 return getImpl()->signedness;
79}
80
81IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
82 if (!scale)
83 return IntegerType();
84 return IntegerType::get(context: getContext(), width: scale * getWidth(), signedness: getSignedness());
85}
86
87//===----------------------------------------------------------------------===//
88// Float Types
89//===----------------------------------------------------------------------===//
90
91// Mapping from MLIR FloatType to APFloat semantics.
92#define FLOAT_TYPE_SEMANTICS(TYPE, SEM) \
93 const llvm::fltSemantics &TYPE::getFloatSemantics() const { \
94 return APFloat::SEM(); \
95 }
96FLOAT_TYPE_SEMANTICS(Float4E2M1FNType, Float4E2M1FN)
97FLOAT_TYPE_SEMANTICS(Float6E2M3FNType, Float6E2M3FN)
98FLOAT_TYPE_SEMANTICS(Float6E3M2FNType, Float6E3M2FN)
99FLOAT_TYPE_SEMANTICS(Float8E5M2Type, Float8E5M2)
100FLOAT_TYPE_SEMANTICS(Float8E4M3Type, Float8E4M3)
101FLOAT_TYPE_SEMANTICS(Float8E4M3FNType, Float8E4M3FN)
102FLOAT_TYPE_SEMANTICS(Float8E5M2FNUZType, Float8E5M2FNUZ)
103FLOAT_TYPE_SEMANTICS(Float8E4M3FNUZType, Float8E4M3FNUZ)
104FLOAT_TYPE_SEMANTICS(Float8E4M3B11FNUZType, Float8E4M3B11FNUZ)
105FLOAT_TYPE_SEMANTICS(Float8E3M4Type, Float8E3M4)
106FLOAT_TYPE_SEMANTICS(Float8E8M0FNUType, Float8E8M0FNU)
107FLOAT_TYPE_SEMANTICS(BFloat16Type, BFloat)
108FLOAT_TYPE_SEMANTICS(Float16Type, IEEEhalf)
109FLOAT_TYPE_SEMANTICS(FloatTF32Type, FloatTF32)
110FLOAT_TYPE_SEMANTICS(Float32Type, IEEEsingle)
111FLOAT_TYPE_SEMANTICS(Float64Type, IEEEdouble)
112FLOAT_TYPE_SEMANTICS(Float80Type, x87DoubleExtended)
113FLOAT_TYPE_SEMANTICS(Float128Type, IEEEquad)
114#undef FLOAT_TYPE_SEMANTICS
115
116FloatType Float16Type::scaleElementBitwidth(unsigned scale) const {
117 if (scale == 2)
118 return Float32Type::get(context: getContext());
119 if (scale == 4)
120 return Float64Type::get(context: getContext());
121 return FloatType();
122}
123
124FloatType BFloat16Type::scaleElementBitwidth(unsigned scale) const {
125 if (scale == 2)
126 return Float32Type::get(context: getContext());
127 if (scale == 4)
128 return Float64Type::get(context: getContext());
129 return FloatType();
130}
131
132FloatType Float32Type::scaleElementBitwidth(unsigned scale) const {
133 if (scale == 2)
134 return Float64Type::get(context: getContext());
135 return FloatType();
136}
137
138//===----------------------------------------------------------------------===//
139// FunctionType
140//===----------------------------------------------------------------------===//
141
142unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
143
144ArrayRef<Type> FunctionType::getInputs() const {
145 return getImpl()->getInputs();
146}
147
148unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
149
150ArrayRef<Type> FunctionType::getResults() const {
151 return getImpl()->getResults();
152}
153
154FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
155 return get(context: getContext(), inputs, results);
156}
157
158/// Returns a new function type with the specified arguments and results
159/// inserted.
160FunctionType FunctionType::getWithArgsAndResults(
161 ArrayRef<unsigned> argIndices, TypeRange argTypes,
162 ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
163 SmallVector<Type> argStorage, resultStorage;
164 TypeRange newArgTypes =
165 insertTypesInto(oldTypes: getInputs(), indices: argIndices, newTypes: argTypes, storage&: argStorage);
166 TypeRange newResultTypes =
167 insertTypesInto(oldTypes: getResults(), indices: resultIndices, newTypes: resultTypes, storage&: resultStorage);
168 return clone(inputs: newArgTypes, results: newResultTypes);
169}
170
171/// Returns a new function type without the specified arguments and results.
172FunctionType
173FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
174 const BitVector &resultIndices) {
175 SmallVector<Type> argStorage, resultStorage;
176 TypeRange newArgTypes = filterTypesOut(types: getInputs(), indices: argIndices, storage&: argStorage);
177 TypeRange newResultTypes =
178 filterTypesOut(types: getResults(), indices: resultIndices, storage&: resultStorage);
179 return clone(inputs: newArgTypes, results: newResultTypes);
180}
181
182//===----------------------------------------------------------------------===//
183// OpaqueType
184//===----------------------------------------------------------------------===//
185
186/// Verify the construction of an opaque type.
187LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
188 StringAttr dialect, StringRef typeData) {
189 if (!Dialect::isValidNamespace(str: dialect.strref()))
190 return emitError() << "invalid dialect namespace '" << dialect << "'";
191
192 // Check that the dialect is actually registered.
193 MLIRContext *context = dialect.getContext();
194 if (!context->allowsUnregisteredDialects() &&
195 !context->getLoadedDialect(name: dialect.strref())) {
196 return emitError()
197 << "`!" << dialect << "<\"" << typeData << "\">"
198 << "` type created with unregistered dialect. If this is "
199 "intended, please call allowUnregisteredDialects() on the "
200 "MLIRContext, or use -allow-unregistered-dialect with "
201 "the MLIR opt tool used";
202 }
203
204 return success();
205}
206
207//===----------------------------------------------------------------------===//
208// VectorType
209//===----------------------------------------------------------------------===//
210
211bool VectorType::isValidElementType(Type t) {
212 return isValidVectorTypeElementType(type: t);
213}
214
215LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
216 ArrayRef<int64_t> shape, Type elementType,
217 ArrayRef<bool> scalableDims) {
218 if (!isValidElementType(t: elementType))
219 return emitError()
220 << "vector elements must be int/index/float type but got "
221 << elementType;
222
223 if (any_of(Range&: shape, P: [](int64_t i) { return i <= 0; }))
224 return emitError()
225 << "vector types must have positive constant sizes but got "
226 << shape;
227
228 if (scalableDims.size() != shape.size())
229 return emitError() << "number of dims must match, got "
230 << scalableDims.size() << " and " << shape.size();
231
232 return success();
233}
234
235VectorType VectorType::scaleElementBitwidth(unsigned scale) {
236 if (!scale)
237 return VectorType();
238 if (auto et = llvm::dyn_cast<IntegerType>(Val: getElementType()))
239 if (auto scaledEt = et.scaleElementBitwidth(scale))
240 return VectorType::get(shape: getShape(), elementType: scaledEt, scalableDims: getScalableDims());
241 if (auto et = llvm::dyn_cast<FloatType>(Val: getElementType()))
242 if (auto scaledEt = et.scaleElementBitwidth(scale))
243 return VectorType::get(shape: getShape(), elementType: scaledEt, scalableDims: getScalableDims());
244 return VectorType();
245}
246
247VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
248 Type elementType) const {
249 return VectorType::get(shape: shape.value_or(u: getShape()), elementType,
250 scalableDims: getScalableDims());
251}
252
253//===----------------------------------------------------------------------===//
254// TensorType
255//===----------------------------------------------------------------------===//
256
257Type TensorType::getElementType() const {
258 return llvm::TypeSwitch<TensorType, Type>(*this)
259 .Case<RankedTensorType, UnrankedTensorType>(
260 caseFn: [](auto type) { return type.getElementType(); });
261}
262
263bool TensorType::hasRank() const {
264 return !llvm::isa<UnrankedTensorType>(Val: *this);
265}
266
267ArrayRef<int64_t> TensorType::getShape() const {
268 return llvm::cast<RankedTensorType>(Val: *this).getShape();
269}
270
271TensorType TensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
272 Type elementType) const {
273 if (llvm::dyn_cast<UnrankedTensorType>(Val: *this)) {
274 if (shape)
275 return RankedTensorType::get(shape: *shape, elementType);
276 return UnrankedTensorType::get(elementType);
277 }
278
279 auto rankedTy = llvm::cast<RankedTensorType>(Val: *this);
280 if (!shape)
281 return RankedTensorType::get(shape: rankedTy.getShape(), elementType,
282 encoding: rankedTy.getEncoding());
283 return RankedTensorType::get(shape: shape.value_or(u: rankedTy.getShape()), elementType,
284 encoding: rankedTy.getEncoding());
285}
286
287RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape,
288 Type elementType) const {
289 return ::llvm::cast<RankedTensorType>(Val: cloneWith(shape, elementType));
290}
291
292RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape) const {
293 return ::llvm::cast<RankedTensorType>(Val: cloneWith(shape, elementType: getElementType()));
294}
295
296// Check if "elementType" can be an element type of a tensor.
297static LogicalResult
298checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
299 Type elementType) {
300 if (!TensorType::isValidElementType(type: elementType))
301 return emitError() << "invalid tensor element type: " << elementType;
302 return success();
303}
304
305/// Return true if the specified element type is ok in a tensor.
306bool TensorType::isValidElementType(Type type) {
307 // Note: Non standard/builtin types are allowed to exist within tensor
308 // types. Dialects are expected to verify that tensor types have a valid
309 // element type within that dialect.
310 return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
311 IndexType>(Val: type) ||
312 !llvm::isa<BuiltinDialect>(Val: type.getDialect());
313}
314
315//===----------------------------------------------------------------------===//
316// RankedTensorType
317//===----------------------------------------------------------------------===//
318
319LogicalResult
320RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
321 ArrayRef<int64_t> shape, Type elementType,
322 Attribute encoding) {
323 for (int64_t s : shape)
324 if (s < 0 && ShapedType::isStatic(dValue: s))
325 return emitError() << "invalid tensor dimension size";
326 if (auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(Val&: encoding))
327 if (failed(Result: v.verifyEncoding(shape, elementType, emitError)))
328 return failure();
329 return checkTensorElementType(emitError, elementType);
330}
331
332//===----------------------------------------------------------------------===//
333// UnrankedTensorType
334//===----------------------------------------------------------------------===//
335
336LogicalResult
337UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
338 Type elementType) {
339 return checkTensorElementType(emitError, elementType);
340}
341
342//===----------------------------------------------------------------------===//
343// BaseMemRefType
344//===----------------------------------------------------------------------===//
345
346Type BaseMemRefType::getElementType() const {
347 return llvm::TypeSwitch<BaseMemRefType, Type>(*this)
348 .Case<MemRefType, UnrankedMemRefType>(
349 caseFn: [](auto type) { return type.getElementType(); });
350}
351
352bool BaseMemRefType::hasRank() const {
353 return !llvm::isa<UnrankedMemRefType>(Val: *this);
354}
355
356ArrayRef<int64_t> BaseMemRefType::getShape() const {
357 return llvm::cast<MemRefType>(Val: *this).getShape();
358}
359
360BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
361 Type elementType) const {
362 if (llvm::dyn_cast<UnrankedMemRefType>(Val: *this)) {
363 if (!shape)
364 return UnrankedMemRefType::get(elementType, memorySpace: getMemorySpace());
365 MemRefType::Builder builder(*shape, elementType);
366 builder.setMemorySpace(getMemorySpace());
367 return builder;
368 }
369
370 MemRefType::Builder builder(llvm::cast<MemRefType>(Val: *this));
371 if (shape)
372 builder.setShape(*shape);
373 builder.setElementType(elementType);
374 return builder;
375}
376
377FailureOr<PtrLikeTypeInterface>
378BaseMemRefType::clonePtrWith(Attribute memorySpace,
379 std::optional<Type> elementType) const {
380 Type eTy = elementType ? *elementType : getElementType();
381 if (llvm::dyn_cast<UnrankedMemRefType>(Val: *this))
382 return cast<PtrLikeTypeInterface>(
383 Val: UnrankedMemRefType::get(elementType: eTy, memorySpace));
384
385 MemRefType::Builder builder(llvm::cast<MemRefType>(Val: *this));
386 builder.setElementType(eTy);
387 builder.setMemorySpace(memorySpace);
388 return cast<PtrLikeTypeInterface>(Val: static_cast<MemRefType>(builder));
389}
390
391MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape,
392 Type elementType) const {
393 return ::llvm::cast<MemRefType>(Val: cloneWith(shape, elementType));
394}
395
396MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const {
397 return ::llvm::cast<MemRefType>(Val: cloneWith(shape, elementType: getElementType()));
398}
399
400Attribute BaseMemRefType::getMemorySpace() const {
401 if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(Val: *this))
402 return rankedMemRefTy.getMemorySpace();
403 return llvm::cast<UnrankedMemRefType>(Val: *this).getMemorySpace();
404}
405
406unsigned BaseMemRefType::getMemorySpaceAsInt() const {
407 if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(Val: *this))
408 return rankedMemRefTy.getMemorySpaceAsInt();
409 return llvm::cast<UnrankedMemRefType>(Val: *this).getMemorySpaceAsInt();
410}
411
412//===----------------------------------------------------------------------===//
413// MemRefType
414//===----------------------------------------------------------------------===//
415
416std::optional<llvm::SmallDenseSet<unsigned>>
417mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
418 ArrayRef<int64_t> reducedShape,
419 bool matchDynamic) {
420 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
421 llvm::SmallDenseSet<unsigned> unusedDims;
422 unsigned reducedIdx = 0;
423 for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
424 // Greedily insert `originalIdx` if match.
425 int64_t origSize = originalShape[originalIdx];
426 // if `matchDynamic`, count dynamic dims as a match, unless `origSize` is 1.
427 if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
428 (ShapedType::isDynamic(dValue: reducedShape[reducedIdx]) ||
429 ShapedType::isDynamic(dValue: origSize))) {
430 reducedIdx++;
431 continue;
432 }
433 if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
434 reducedIdx++;
435 continue;
436 }
437
438 unusedDims.insert(V: originalIdx);
439 // If no match on `originalIdx`, the `originalShape` at this dimension
440 // must be 1, otherwise we bail.
441 if (origSize != 1)
442 return std::nullopt;
443 }
444 // The whole reducedShape must be scanned, otherwise we bail.
445 if (reducedIdx != reducedRank)
446 return std::nullopt;
447 return unusedDims;
448}
449
450SliceVerificationResult
451mlir::isRankReducedType(ShapedType originalType,
452 ShapedType candidateReducedType) {
453 if (originalType == candidateReducedType)
454 return SliceVerificationResult::Success;
455
456 ShapedType originalShapedType = llvm::cast<ShapedType>(Val&: originalType);
457 ShapedType candidateReducedShapedType =
458 llvm::cast<ShapedType>(Val&: candidateReducedType);
459
460 // Rank and size logic is valid for all ShapedTypes.
461 ArrayRef<int64_t> originalShape = originalShapedType.getShape();
462 ArrayRef<int64_t> candidateReducedShape =
463 candidateReducedShapedType.getShape();
464 unsigned originalRank = originalShape.size(),
465 candidateReducedRank = candidateReducedShape.size();
466 if (candidateReducedRank > originalRank)
467 return SliceVerificationResult::RankTooLarge;
468
469 auto optionalUnusedDimsMask =
470 computeRankReductionMask(originalShape, reducedShape: candidateReducedShape);
471
472 // Sizes cannot be matched in case empty vector is returned.
473 if (!optionalUnusedDimsMask)
474 return SliceVerificationResult::SizeMismatch;
475
476 if (originalShapedType.getElementType() !=
477 candidateReducedShapedType.getElementType())
478 return SliceVerificationResult::ElemTypeMismatch;
479
480 return SliceVerificationResult::Success;
481}
482
483bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
484 // Empty attribute is allowed as default memory space.
485 if (!memorySpace)
486 return true;
487
488 // Supported built-in attributes.
489 if (llvm::isa<IntegerAttr, StringAttr, DictionaryAttr>(Val: memorySpace))
490 return true;
491
492 // Allow custom dialect attributes.
493 if (!isa<BuiltinDialect>(Val: memorySpace.getDialect()))
494 return true;
495
496 return false;
497}
498
499Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
500 MLIRContext *ctx) {
501 if (memorySpace == 0)
502 return nullptr;
503
504 return IntegerAttr::get(type: IntegerType::get(context: ctx, width: 64), value: memorySpace);
505}
506
507Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
508 IntegerAttr intMemorySpace = llvm::dyn_cast_or_null<IntegerAttr>(Val&: memorySpace);
509 if (intMemorySpace && intMemorySpace.getValue() == 0)
510 return nullptr;
511
512 return memorySpace;
513}
514
515unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
516 if (!memorySpace)
517 return 0;
518
519 assert(llvm::isa<IntegerAttr>(memorySpace) &&
520 "Using `getMemorySpaceInteger` with non-Integer attribute");
521
522 return static_cast<unsigned>(llvm::cast<IntegerAttr>(Val&: memorySpace).getInt());
523}
524
525unsigned MemRefType::getMemorySpaceAsInt() const {
526 return detail::getMemorySpaceAsInt(memorySpace: getMemorySpace());
527}
528
529MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
530 MemRefLayoutAttrInterface layout,
531 Attribute memorySpace) {
532 // Use default layout for empty attribute.
533 if (!layout)
534 layout = AffineMapAttr::get(value: AffineMap::getMultiDimIdentityMap(
535 numDims: shape.size(), context: elementType.getContext()));
536
537 // Drop default memory space value and replace it with empty attribute.
538 memorySpace = skipDefaultMemorySpace(memorySpace);
539
540 return Base::get(ctx: elementType.getContext(), args&: shape, args&: elementType, args&: layout,
541 args&: memorySpace);
542}
543
544MemRefType MemRefType::getChecked(
545 function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
546 Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
547
548 // Use default layout for empty attribute.
549 if (!layout)
550 layout = AffineMapAttr::get(value: AffineMap::getMultiDimIdentityMap(
551 numDims: shape.size(), context: elementType.getContext()));
552
553 // Drop default memory space value and replace it with empty attribute.
554 memorySpace = skipDefaultMemorySpace(memorySpace);
555
556 return Base::getChecked(emitErrorFn, ctx: elementType.getContext(), args: shape,
557 args: elementType, args: layout, args: memorySpace);
558}
559
560MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
561 AffineMap map, Attribute memorySpace) {
562
563 // Use default layout for empty map.
564 if (!map)
565 map = AffineMap::getMultiDimIdentityMap(numDims: shape.size(),
566 context: elementType.getContext());
567
568 // Wrap AffineMap into Attribute.
569 auto layout = AffineMapAttr::get(value: map);
570
571 // Drop default memory space value and replace it with empty attribute.
572 memorySpace = skipDefaultMemorySpace(memorySpace);
573
574 return Base::get(ctx: elementType.getContext(), args&: shape, args&: elementType, args&: layout,
575 args&: memorySpace);
576}
577
578MemRefType
579MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
580 ArrayRef<int64_t> shape, Type elementType, AffineMap map,
581 Attribute memorySpace) {
582
583 // Use default layout for empty map.
584 if (!map)
585 map = AffineMap::getMultiDimIdentityMap(numDims: shape.size(),
586 context: elementType.getContext());
587
588 // Wrap AffineMap into Attribute.
589 auto layout = AffineMapAttr::get(value: map);
590
591 // Drop default memory space value and replace it with empty attribute.
592 memorySpace = skipDefaultMemorySpace(memorySpace);
593
594 return Base::getChecked(emitErrorFn, ctx: elementType.getContext(), args: shape,
595 args: elementType, args: layout, args: memorySpace);
596}
597
598MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
599 AffineMap map, unsigned memorySpaceInd) {
600
601 // Use default layout for empty map.
602 if (!map)
603 map = AffineMap::getMultiDimIdentityMap(numDims: shape.size(),
604 context: elementType.getContext());
605
606 // Wrap AffineMap into Attribute.
607 auto layout = AffineMapAttr::get(value: map);
608
609 // Convert deprecated integer-like memory space to Attribute.
610 Attribute memorySpace =
611 wrapIntegerMemorySpace(memorySpace: memorySpaceInd, ctx: elementType.getContext());
612
613 return Base::get(ctx: elementType.getContext(), args&: shape, args&: elementType, args&: layout,
614 args&: memorySpace);
615}
616
617MemRefType
618MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
619 ArrayRef<int64_t> shape, Type elementType, AffineMap map,
620 unsigned memorySpaceInd) {
621
622 // Use default layout for empty map.
623 if (!map)
624 map = AffineMap::getMultiDimIdentityMap(numDims: shape.size(),
625 context: elementType.getContext());
626
627 // Wrap AffineMap into Attribute.
628 auto layout = AffineMapAttr::get(value: map);
629
630 // Convert deprecated integer-like memory space to Attribute.
631 Attribute memorySpace =
632 wrapIntegerMemorySpace(memorySpace: memorySpaceInd, ctx: elementType.getContext());
633
634 return Base::getChecked(emitErrorFn, ctx: elementType.getContext(), args: shape,
635 args: elementType, args: layout, args: memorySpace);
636}
637
638LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
639 ArrayRef<int64_t> shape, Type elementType,
640 MemRefLayoutAttrInterface layout,
641 Attribute memorySpace) {
642 if (!BaseMemRefType::isValidElementType(type: elementType))
643 return emitError() << "invalid memref element type";
644
645 // Negative sizes are not allowed except for `kDynamic`.
646 for (int64_t s : shape)
647 if (s < 0 && ShapedType::isStatic(dValue: s))
648 return emitError() << "invalid memref size";
649
650 assert(layout && "missing layout specification");
651 if (failed(Result: layout.verifyLayout(shape, emitError)))
652 return failure();
653
654 if (!isSupportedMemorySpace(memorySpace))
655 return emitError() << "unsupported memory space Attribute";
656
657 return success();
658}
659
660bool MemRefType::areTrailingDimsContiguous(int64_t n) {
661 assert(n <= getRank() &&
662 "number of dimensions to check must not exceed rank");
663 return n <= getNumContiguousTrailingDims();
664}
665
666int64_t MemRefType::getNumContiguousTrailingDims() {
667 const int64_t n = getRank();
668
669 // memrefs with identity layout are entirely contiguous.
670 if (getLayout().isIdentity())
671 return n;
672
673 // Get the strides (if any). Failing to do that, conservatively assume a
674 // non-contiguous layout.
675 int64_t offset;
676 SmallVector<int64_t> strides;
677 if (!succeeded(Result: getStridesAndOffset(strides, offset)))
678 return 0;
679
680 ArrayRef<int64_t> shape = getShape();
681
682 // A memref with dimensions `d0, d1, ..., dn-1` and strides
683 // `s0, s1, ..., sn-1` is contiguous up to dimension `k`
684 // if each stride `si` is the product of the dimensions `di+1, ..., dn-1`,
685 // for `i` in `[k, n-1]`.
686 // Ignore stride elements if the corresponding dimension is 1, as they are
687 // of no consequence.
688 int64_t dimProduct = 1;
689 for (int64_t i = n - 1; i >= 0; --i) {
690 if (shape[i] == 1)
691 continue;
692 if (strides[i] != dimProduct)
693 return n - i - 1;
694 if (shape[i] == ShapedType::kDynamic)
695 return n - i;
696 dimProduct *= shape[i];
697 }
698
699 return n;
700}
701
702MemRefType MemRefType::canonicalizeStridedLayout() {
703 AffineMap m = getLayout().getAffineMap();
704
705 // Already in canonical form.
706 if (m.isIdentity())
707 return *this;
708
709 // Can't reduce to canonical identity form, return in canonical form.
710 if (m.getNumResults() > 1)
711 return *this;
712
713 // Corner-case for 0-D affine maps.
714 if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
715 if (auto cst = llvm::dyn_cast<AffineConstantExpr>(Val: m.getResult(idx: 0)))
716 if (cst.getValue() == 0)
717 return MemRefType::Builder(*this).setLayout({});
718 return *this;
719 }
720
721 // 0-D corner case for empty shape that still have an affine map. Example:
722 // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
723 // offset needs to remain, just return t.
724 if (getShape().empty())
725 return *this;
726
727 // If the canonical strided layout for the sizes of `t` is equal to the
728 // simplified layout of `t` we can just return an empty layout. Otherwise,
729 // just simplify the existing layout.
730 AffineExpr expr = makeCanonicalStridedLayoutExpr(sizes: getShape(), context: getContext());
731 auto simplifiedLayoutExpr =
732 simplifyAffineExpr(expr: m.getResult(idx: 0), numDims: m.getNumDims(), numSymbols: m.getNumSymbols());
733 if (expr != simplifiedLayoutExpr)
734 return MemRefType::Builder(*this).setLayout(
735 AffineMapAttr::get(value: AffineMap::get(dimCount: m.getNumDims(), symbolCount: m.getNumSymbols(),
736 result: simplifiedLayoutExpr)));
737 return MemRefType::Builder(*this).setLayout({});
738}
739
740LogicalResult MemRefType::getStridesAndOffset(SmallVectorImpl<int64_t> &strides,
741 int64_t &offset) const {
742 return getLayout().getStridesAndOffset(shape: getShape(), strides, offset);
743}
744
745std::pair<SmallVector<int64_t>, int64_t>
746MemRefType::getStridesAndOffset() const {
747 SmallVector<int64_t> strides;
748 int64_t offset;
749 LogicalResult status = getStridesAndOffset(strides, offset);
750 (void)status;
751 assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset");
752 return {strides, offset};
753}
754
755bool MemRefType::isStrided() {
756 int64_t offset;
757 SmallVector<int64_t, 4> strides;
758 auto res = getStridesAndOffset(strides, offset);
759 return succeeded(Result: res);
760}
761
762bool MemRefType::isLastDimUnitStride() {
763 int64_t offset;
764 SmallVector<int64_t> strides;
765 auto successStrides = getStridesAndOffset(strides, offset);
766 return succeeded(Result: successStrides) && (strides.empty() || strides.back() == 1);
767}
768
769//===----------------------------------------------------------------------===//
770// UnrankedMemRefType
771//===----------------------------------------------------------------------===//
772
773unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
774 return detail::getMemorySpaceAsInt(memorySpace: getMemorySpace());
775}
776
777LogicalResult
778UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
779 Type elementType, Attribute memorySpace) {
780 if (!BaseMemRefType::isValidElementType(type: elementType))
781 return emitError() << "invalid memref element type";
782
783 if (!isSupportedMemorySpace(memorySpace))
784 return emitError() << "unsupported memory space Attribute";
785
786 return success();
787}
788
789//===----------------------------------------------------------------------===//
// TupleType
791//===----------------------------------------------------------------------===//
792
793/// Return the elements types for this tuple.
794ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
795
796/// Accumulate the types contained in this tuple and tuples nested within it.
797/// Note that this only flattens nested tuples, not any other container type,
798/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
799/// (i32, tensor<i32>, f32, i64)
800void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
801 for (Type type : getTypes()) {
802 if (auto nestedTuple = llvm::dyn_cast<TupleType>(Val&: type))
803 nestedTuple.getFlattenedTypes(types);
804 else
805 types.push_back(Elt: type);
806 }
807}
808
809/// Return the number of element types.
810size_t TupleType::size() const { return getImpl()->size(); }
811
812//===----------------------------------------------------------------------===//
813// Type Utilities
814//===----------------------------------------------------------------------===//
815
816AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
817 ArrayRef<AffineExpr> exprs,
818 MLIRContext *context) {
819 // Size 0 corner case is useful for canonicalizations.
820 if (sizes.empty())
821 return getAffineConstantExpr(constant: 0, context);
822
823 assert(!exprs.empty() && "expected exprs");
824 auto maps = AffineMap::inferFromExprList(exprsList: exprs, context);
825 assert(!maps.empty() && "Expected one non-empty map");
826 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
827
828 AffineExpr expr;
829 bool dynamicPoisonBit = false;
830 int64_t runningSize = 1;
831 for (auto en : llvm::zip(t: llvm::reverse(C&: exprs), u: llvm::reverse(C&: sizes))) {
832 int64_t size = std::get<1>(t&: en);
833 AffineExpr dimExpr = std::get<0>(t&: en);
834 AffineExpr stride = dynamicPoisonBit
835 ? getAffineSymbolExpr(position: nSymbols++, context)
836 : getAffineConstantExpr(constant: runningSize, context);
837 expr = expr ? expr + dimExpr * stride : dimExpr * stride;
838 if (size > 0) {
839 runningSize *= size;
840 assert(runningSize > 0 && "integer overflow in size computation");
841 } else {
842 dynamicPoisonBit = true;
843 }
844 }
845 return simplifyAffineExpr(expr, numDims, numSymbols: nSymbols);
846}
847
848AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
849 MLIRContext *context) {
850 SmallVector<AffineExpr, 4> exprs;
851 exprs.reserve(N: sizes.size());
852 for (auto dim : llvm::seq<unsigned>(Begin: 0, End: sizes.size()))
853 exprs.push_back(Elt: getAffineDimExpr(position: dim, context));
854 return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
855}
856

source code of mlir/lib/IR/BuiltinTypes.cpp