1//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/BuiltinTypes.h"
10#include "TypeDetail.h"
11#include "mlir/IR/AffineExpr.h"
12#include "mlir/IR/AffineMap.h"
13#include "mlir/IR/BuiltinAttributes.h"
14#include "mlir/IR/BuiltinDialect.h"
15#include "mlir/IR/Diagnostics.h"
16#include "mlir/IR/Dialect.h"
17#include "mlir/IR/TensorEncoding.h"
18#include "mlir/IR/TypeUtilities.h"
19#include "llvm/ADT/APFloat.h"
20#include "llvm/ADT/BitVector.h"
21#include "llvm/ADT/Sequence.h"
22#include "llvm/ADT/Twine.h"
23#include "llvm/ADT/TypeSwitch.h"
24
25using namespace mlir;
26using namespace mlir::detail;
27
28//===----------------------------------------------------------------------===//
29/// Tablegen Type Definitions
30//===----------------------------------------------------------------------===//
31
32#define GET_TYPEDEF_CLASSES
33#include "mlir/IR/BuiltinTypes.cpp.inc"
34
35//===----------------------------------------------------------------------===//
36// BuiltinDialect
37//===----------------------------------------------------------------------===//
38
/// Registers all builtin types with the builtin dialect. The template argument
/// list of `addTypes` is produced by TableGen: defining GET_TYPEDEF_LIST before
/// including BuiltinTypes.cpp.inc is expected to make that include expand to
/// the list of generated type classes.
void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}
45
46//===----------------------------------------------------------------------===//
47/// ComplexType
48//===----------------------------------------------------------------------===//
49
50/// Verify the construction of an integer type.
51LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
52 Type elementType) {
53 if (!elementType.isIntOrFloat())
54 return emitError() << "invalid element type for complex";
55 return success();
56}
57
58//===----------------------------------------------------------------------===//
59// Integer Type
60//===----------------------------------------------------------------------===//
61
62/// Verify the construction of an integer type.
63LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
64 unsigned width,
65 SignednessSemantics signedness) {
66 if (width > IntegerType::kMaxWidth) {
67 return emitError() << "integer bitwidth is limited to "
68 << IntegerType::kMaxWidth << " bits";
69 }
70 return success();
71}
72
73unsigned IntegerType::getWidth() const { return getImpl()->width; }
74
75IntegerType::SignednessSemantics IntegerType::getSignedness() const {
76 return getImpl()->signedness;
77}
78
79IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
80 if (!scale)
81 return IntegerType();
82 return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
83}
84
85//===----------------------------------------------------------------------===//
86// Float Type
87//===----------------------------------------------------------------------===//
88
89unsigned FloatType::getWidth() {
90 if (llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
91 Float8E4M3FNUZType, Float8E4M3B11FNUZType>(*this))
92 return 8;
93 if (llvm::isa<Float16Type, BFloat16Type>(*this))
94 return 16;
95 if (llvm::isa<Float32Type, FloatTF32Type>(*this))
96 return 32;
97 if (llvm::isa<Float64Type>(*this))
98 return 64;
99 if (llvm::isa<Float80Type>(*this))
100 return 80;
101 if (llvm::isa<Float128Type>(*this))
102 return 128;
103 llvm_unreachable("unexpected float type");
104}
105
106/// Returns the floating semantics for the given type.
107const llvm::fltSemantics &FloatType::getFloatSemantics() {
108 if (llvm::isa<Float8E5M2Type>(*this))
109 return APFloat::Float8E5M2();
110 if (llvm::isa<Float8E4M3FNType>(*this))
111 return APFloat::Float8E4M3FN();
112 if (llvm::isa<Float8E5M2FNUZType>(*this))
113 return APFloat::Float8E5M2FNUZ();
114 if (llvm::isa<Float8E4M3FNUZType>(*this))
115 return APFloat::Float8E4M3FNUZ();
116 if (llvm::isa<Float8E4M3B11FNUZType>(*this))
117 return APFloat::Float8E4M3B11FNUZ();
118 if (llvm::isa<BFloat16Type>(*this))
119 return APFloat::BFloat();
120 if (llvm::isa<Float16Type>(*this))
121 return APFloat::IEEEhalf();
122 if (llvm::isa<FloatTF32Type>(*this))
123 return APFloat::FloatTF32();
124 if (llvm::isa<Float32Type>(*this))
125 return APFloat::IEEEsingle();
126 if (llvm::isa<Float64Type>(*this))
127 return APFloat::IEEEdouble();
128 if (llvm::isa<Float80Type>(*this))
129 return APFloat::x87DoubleExtended();
130 if (llvm::isa<Float128Type>(*this))
131 return APFloat::IEEEquad();
132 llvm_unreachable("non-floating point type used");
133}
134
135FloatType FloatType::scaleElementBitwidth(unsigned scale) {
136 if (!scale)
137 return FloatType();
138 MLIRContext *ctx = getContext();
139 if (isF16() || isBF16()) {
140 if (scale == 2)
141 return FloatType::getF32(ctx);
142 if (scale == 4)
143 return FloatType::getF64(ctx);
144 }
145 if (isF32())
146 if (scale == 2)
147 return FloatType::getF64(ctx);
148 return FloatType();
149}
150
151unsigned FloatType::getFPMantissaWidth() {
152 return APFloat::semanticsPrecision(getFloatSemantics());
153}
154
155//===----------------------------------------------------------------------===//
156// FunctionType
157//===----------------------------------------------------------------------===//
158
159unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
160
161ArrayRef<Type> FunctionType::getInputs() const {
162 return getImpl()->getInputs();
163}
164
165unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
166
167ArrayRef<Type> FunctionType::getResults() const {
168 return getImpl()->getResults();
169}
170
171FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
172 return get(getContext(), inputs, results);
173}
174
175/// Returns a new function type with the specified arguments and results
176/// inserted.
177FunctionType FunctionType::getWithArgsAndResults(
178 ArrayRef<unsigned> argIndices, TypeRange argTypes,
179 ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
180 SmallVector<Type> argStorage, resultStorage;
181 TypeRange newArgTypes =
182 insertTypesInto(getInputs(), argIndices, argTypes, argStorage);
183 TypeRange newResultTypes =
184 insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
185 return clone(newArgTypes, newResultTypes);
186}
187
188/// Returns a new function type without the specified arguments and results.
189FunctionType
190FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
191 const BitVector &resultIndices) {
192 SmallVector<Type> argStorage, resultStorage;
193 TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage);
194 TypeRange newResultTypes =
195 filterTypesOut(getResults(), resultIndices, resultStorage);
196 return clone(newArgTypes, newResultTypes);
197}
198
199//===----------------------------------------------------------------------===//
200// OpaqueType
201//===----------------------------------------------------------------------===//
202
203/// Verify the construction of an opaque type.
204LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
205 StringAttr dialect, StringRef typeData) {
206 if (!Dialect::isValidNamespace(dialect.strref()))
207 return emitError() << "invalid dialect namespace '" << dialect << "'";
208
209 // Check that the dialect is actually registered.
210 MLIRContext *context = dialect.getContext();
211 if (!context->allowsUnregisteredDialects() &&
212 !context->getLoadedDialect(dialect.strref())) {
213 return emitError()
214 << "`!" << dialect << "<\"" << typeData << "\">"
215 << "` type created with unregistered dialect. If this is "
216 "intended, please call allowUnregisteredDialects() on the "
217 "MLIRContext, or use -allow-unregistered-dialect with "
218 "the MLIR opt tool used";
219 }
220
221 return success();
222}
223
224//===----------------------------------------------------------------------===//
225// VectorType
226//===----------------------------------------------------------------------===//
227
228LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
229 ArrayRef<int64_t> shape, Type elementType,
230 ArrayRef<bool> scalableDims) {
231 if (!isValidElementType(elementType))
232 return emitError()
233 << "vector elements must be int/index/float type but got "
234 << elementType;
235
236 if (any_of(shape, [](int64_t i) { return i <= 0; }))
237 return emitError()
238 << "vector types must have positive constant sizes but got "
239 << shape;
240
241 if (scalableDims.size() != shape.size())
242 return emitError() << "number of dims must match, got "
243 << scalableDims.size() << " and " << shape.size();
244
245 return success();
246}
247
248VectorType VectorType::scaleElementBitwidth(unsigned scale) {
249 if (!scale)
250 return VectorType();
251 if (auto et = llvm::dyn_cast<IntegerType>(getElementType()))
252 if (auto scaledEt = et.scaleElementBitwidth(scale))
253 return VectorType::get(getShape(), scaledEt, getScalableDims());
254 if (auto et = llvm::dyn_cast<FloatType>(getElementType()))
255 if (auto scaledEt = et.scaleElementBitwidth(scale))
256 return VectorType::get(getShape(), scaledEt, getScalableDims());
257 return VectorType();
258}
259
260VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
261 Type elementType) const {
262 return VectorType::get(shape.value_or(getShape()), elementType,
263 getScalableDims());
264}
265
266//===----------------------------------------------------------------------===//
267// TensorType
268//===----------------------------------------------------------------------===//
269
270Type TensorType::getElementType() const {
271 return llvm::TypeSwitch<TensorType, Type>(*this)
272 .Case<RankedTensorType, UnrankedTensorType>(
273 [](auto type) { return type.getElementType(); });
274}
275
276bool TensorType::hasRank() const { return !llvm::isa<UnrankedTensorType>(*this); }
277
278ArrayRef<int64_t> TensorType::getShape() const {
279 return llvm::cast<RankedTensorType>(*this).getShape();
280}
281
282TensorType TensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
283 Type elementType) const {
284 if (llvm::dyn_cast<UnrankedTensorType>(*this)) {
285 if (shape)
286 return RankedTensorType::get(*shape, elementType);
287 return UnrankedTensorType::get(elementType);
288 }
289
290 auto rankedTy = llvm::cast<RankedTensorType>(*this);
291 if (!shape)
292 return RankedTensorType::get(rankedTy.getShape(), elementType,
293 rankedTy.getEncoding());
294 return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType,
295 rankedTy.getEncoding());
296}
297
298RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape,
299 Type elementType) const {
300 return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
301}
302
303RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape) const {
304 return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType: getElementType()));
305}
306
307// Check if "elementType" can be an element type of a tensor.
308static LogicalResult
309checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
310 Type elementType) {
311 if (!TensorType::isValidElementType(type: elementType))
312 return emitError() << "invalid tensor element type: " << elementType;
313 return success();
314}
315
316/// Return true if the specified element type is ok in a tensor.
317bool TensorType::isValidElementType(Type type) {
318 // Note: Non standard/builtin types are allowed to exist within tensor
319 // types. Dialects are expected to verify that tensor types have a valid
320 // element type within that dialect.
321 return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
322 IndexType>(type) ||
323 !llvm::isa<BuiltinDialect>(type.getDialect());
324}
325
326//===----------------------------------------------------------------------===//
327// RankedTensorType
328//===----------------------------------------------------------------------===//
329
330LogicalResult
331RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
332 ArrayRef<int64_t> shape, Type elementType,
333 Attribute encoding) {
334 for (int64_t s : shape)
335 if (s < 0 && !ShapedType::isDynamic(s))
336 return emitError() << "invalid tensor dimension size";
337 if (auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
338 if (failed(v.verifyEncoding(shape, elementType, emitError)))
339 return failure();
340 return checkTensorElementType(emitError, elementType);
341}
342
343//===----------------------------------------------------------------------===//
344// UnrankedTensorType
345//===----------------------------------------------------------------------===//
346
347LogicalResult
348UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
349 Type elementType) {
350 return checkTensorElementType(emitError, elementType);
351}
352
353//===----------------------------------------------------------------------===//
354// BaseMemRefType
355//===----------------------------------------------------------------------===//
356
357Type BaseMemRefType::getElementType() const {
358 return llvm::TypeSwitch<BaseMemRefType, Type>(*this)
359 .Case<MemRefType, UnrankedMemRefType>(
360 [](auto type) { return type.getElementType(); });
361}
362
363bool BaseMemRefType::hasRank() const { return !llvm::isa<UnrankedMemRefType>(*this); }
364
365ArrayRef<int64_t> BaseMemRefType::getShape() const {
366 return llvm::cast<MemRefType>(*this).getShape();
367}
368
369BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
370 Type elementType) const {
371 if (llvm::dyn_cast<UnrankedMemRefType>(*this)) {
372 if (!shape)
373 return UnrankedMemRefType::get(elementType, getMemorySpace());
374 MemRefType::Builder builder(*shape, elementType);
375 builder.setMemorySpace(getMemorySpace());
376 return builder;
377 }
378
379 MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
380 if (shape)
381 builder.setShape(*shape);
382 builder.setElementType(elementType);
383 return builder;
384}
385
386MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape,
387 Type elementType) const {
388 return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
389}
390
391MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const {
392 return ::llvm::cast<MemRefType>(cloneWith(shape, elementType: getElementType()));
393}
394
395Attribute BaseMemRefType::getMemorySpace() const {
396 if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
397 return rankedMemRefTy.getMemorySpace();
398 return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
399}
400
401unsigned BaseMemRefType::getMemorySpaceAsInt() const {
402 if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
403 return rankedMemRefTy.getMemorySpaceAsInt();
404 return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
405}
406
407//===----------------------------------------------------------------------===//
408// MemRefType
409//===----------------------------------------------------------------------===//
410
411/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
412/// `originalShape` with some `1` entries erased, return the set of indices
413/// that specifies which of the entries of `originalShape` are dropped to obtain
414/// `reducedShape`. The returned mask can be applied as a projection to
415/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
416/// which dimensions must be kept when e.g. compute MemRef strides under
417/// rank-reducing operations. Return std::nullopt if reducedShape cannot be
418/// obtained by dropping only `1` entries in `originalShape`.
419std::optional<llvm::SmallDenseSet<unsigned>>
420mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
421 ArrayRef<int64_t> reducedShape) {
422 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
423 llvm::SmallDenseSet<unsigned> unusedDims;
424 unsigned reducedIdx = 0;
425 for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
426 // Greedily insert `originalIdx` if match.
427 if (reducedIdx < reducedRank &&
428 originalShape[originalIdx] == reducedShape[reducedIdx]) {
429 reducedIdx++;
430 continue;
431 }
432
433 unusedDims.insert(V: originalIdx);
434 // If no match on `originalIdx`, the `originalShape` at this dimension
435 // must be 1, otherwise we bail.
436 if (originalShape[originalIdx] != 1)
437 return std::nullopt;
438 }
439 // The whole reducedShape must be scanned, otherwise we bail.
440 if (reducedIdx != reducedRank)
441 return std::nullopt;
442 return unusedDims;
443}
444
445SliceVerificationResult
446mlir::isRankReducedType(ShapedType originalType,
447 ShapedType candidateReducedType) {
448 if (originalType == candidateReducedType)
449 return SliceVerificationResult::Success;
450
451 ShapedType originalShapedType = llvm::cast<ShapedType>(originalType);
452 ShapedType candidateReducedShapedType =
453 llvm::cast<ShapedType>(candidateReducedType);
454
455 // Rank and size logic is valid for all ShapedTypes.
456 ArrayRef<int64_t> originalShape = originalShapedType.getShape();
457 ArrayRef<int64_t> candidateReducedShape =
458 candidateReducedShapedType.getShape();
459 unsigned originalRank = originalShape.size(),
460 candidateReducedRank = candidateReducedShape.size();
461 if (candidateReducedRank > originalRank)
462 return SliceVerificationResult::RankTooLarge;
463
464 auto optionalUnusedDimsMask =
465 computeRankReductionMask(originalShape, candidateReducedShape);
466
467 // Sizes cannot be matched in case empty vector is returned.
468 if (!optionalUnusedDimsMask)
469 return SliceVerificationResult::SizeMismatch;
470
471 if (originalShapedType.getElementType() !=
472 candidateReducedShapedType.getElementType())
473 return SliceVerificationResult::ElemTypeMismatch;
474
475 return SliceVerificationResult::Success;
476}
477
478bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
479 // Empty attribute is allowed as default memory space.
480 if (!memorySpace)
481 return true;
482
483 // Supported built-in attributes.
484 if (llvm::isa<IntegerAttr, StringAttr, DictionaryAttr>(memorySpace))
485 return true;
486
487 // Allow custom dialect attributes.
488 if (!isa<BuiltinDialect>(Val: memorySpace.getDialect()))
489 return true;
490
491 return false;
492}
493
494Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
495 MLIRContext *ctx) {
496 if (memorySpace == 0)
497 return nullptr;
498
499 return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
500}
501
502Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
503 IntegerAttr intMemorySpace = llvm::dyn_cast_or_null<IntegerAttr>(memorySpace);
504 if (intMemorySpace && intMemorySpace.getValue() == 0)
505 return nullptr;
506
507 return memorySpace;
508}
509
510unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
511 if (!memorySpace)
512 return 0;
513
514 assert(llvm::isa<IntegerAttr>(memorySpace) &&
515 "Using `getMemorySpaceInteger` with non-Integer attribute");
516
517 return static_cast<unsigned>(llvm::cast<IntegerAttr>(memorySpace).getInt());
518}
519
520unsigned MemRefType::getMemorySpaceAsInt() const {
521 return detail::getMemorySpaceAsInt(getMemorySpace());
522}
523
524MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
525 MemRefLayoutAttrInterface layout,
526 Attribute memorySpace) {
527 // Use default layout for empty attribute.
528 if (!layout)
529 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
530 shape.size(), elementType.getContext()));
531
532 // Drop default memory space value and replace it with empty attribute.
533 memorySpace = skipDefaultMemorySpace(memorySpace);
534
535 return Base::get(elementType.getContext(), shape, elementType, layout,
536 memorySpace);
537}
538
539MemRefType MemRefType::getChecked(
540 function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
541 Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
542
543 // Use default layout for empty attribute.
544 if (!layout)
545 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
546 shape.size(), elementType.getContext()));
547
548 // Drop default memory space value and replace it with empty attribute.
549 memorySpace = skipDefaultMemorySpace(memorySpace);
550
551 return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
552 elementType, layout, memorySpace);
553}
554
555MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
556 AffineMap map, Attribute memorySpace) {
557
558 // Use default layout for empty map.
559 if (!map)
560 map = AffineMap::getMultiDimIdentityMap(shape.size(),
561 elementType.getContext());
562
563 // Wrap AffineMap into Attribute.
564 auto layout = AffineMapAttr::get(map);
565
566 // Drop default memory space value and replace it with empty attribute.
567 memorySpace = skipDefaultMemorySpace(memorySpace);
568
569 return Base::get(elementType.getContext(), shape, elementType, layout,
570 memorySpace);
571}
572
573MemRefType
574MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
575 ArrayRef<int64_t> shape, Type elementType, AffineMap map,
576 Attribute memorySpace) {
577
578 // Use default layout for empty map.
579 if (!map)
580 map = AffineMap::getMultiDimIdentityMap(shape.size(),
581 elementType.getContext());
582
583 // Wrap AffineMap into Attribute.
584 auto layout = AffineMapAttr::get(map);
585
586 // Drop default memory space value and replace it with empty attribute.
587 memorySpace = skipDefaultMemorySpace(memorySpace);
588
589 return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
590 elementType, layout, memorySpace);
591}
592
593MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
594 AffineMap map, unsigned memorySpaceInd) {
595
596 // Use default layout for empty map.
597 if (!map)
598 map = AffineMap::getMultiDimIdentityMap(shape.size(),
599 elementType.getContext());
600
601 // Wrap AffineMap into Attribute.
602 auto layout = AffineMapAttr::get(map);
603
604 // Convert deprecated integer-like memory space to Attribute.
605 Attribute memorySpace =
606 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
607
608 return Base::get(elementType.getContext(), shape, elementType, layout,
609 memorySpace);
610}
611
612MemRefType
613MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
614 ArrayRef<int64_t> shape, Type elementType, AffineMap map,
615 unsigned memorySpaceInd) {
616
617 // Use default layout for empty map.
618 if (!map)
619 map = AffineMap::getMultiDimIdentityMap(shape.size(),
620 elementType.getContext());
621
622 // Wrap AffineMap into Attribute.
623 auto layout = AffineMapAttr::get(map);
624
625 // Convert deprecated integer-like memory space to Attribute.
626 Attribute memorySpace =
627 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
628
629 return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
630 elementType, layout, memorySpace);
631}
632
633LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
634 ArrayRef<int64_t> shape, Type elementType,
635 MemRefLayoutAttrInterface layout,
636 Attribute memorySpace) {
637 if (!BaseMemRefType::isValidElementType(elementType))
638 return emitError() << "invalid memref element type";
639
640 // Negative sizes are not allowed except for `kDynamic`.
641 for (int64_t s : shape)
642 if (s < 0 && !ShapedType::isDynamic(s))
643 return emitError() << "invalid memref size";
644
645 assert(layout && "missing layout specification");
646 if (failed(layout.verifyLayout(shape, emitError)))
647 return failure();
648
649 if (!isSupportedMemorySpace(memorySpace))
650 return emitError() << "unsupported memory space Attribute";
651
652 return success();
653}
654
655//===----------------------------------------------------------------------===//
656// UnrankedMemRefType
657//===----------------------------------------------------------------------===//
658
659unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
660 return detail::getMemorySpaceAsInt(getMemorySpace());
661}
662
663LogicalResult
664UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
665 Type elementType, Attribute memorySpace) {
666 if (!BaseMemRefType::isValidElementType(elementType))
667 return emitError() << "invalid memref element type";
668
669 if (!isSupportedMemorySpace(memorySpace))
670 return emitError() << "unsupported memory space Attribute";
671
672 return success();
673}
674
675// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
676// i.e. single term). Accumulate the AffineExpr into the existing one.
677static void extractStridesFromTerm(AffineExpr e,
678 AffineExpr multiplicativeFactor,
679 MutableArrayRef<AffineExpr> strides,
680 AffineExpr &offset) {
681 if (auto dim = dyn_cast<AffineDimExpr>(Val&: e))
682 strides[dim.getPosition()] =
683 strides[dim.getPosition()] + multiplicativeFactor;
684 else
685 offset = offset + e * multiplicativeFactor;
686}
687
688/// Takes a single AffineExpr `e` and populates the `strides` array with the
689/// strides expressions for each dim position.
690/// The convention is that the strides for dimensions d0, .. dn appear in
691/// order to make indexing intuitive into the result.
692static LogicalResult extractStrides(AffineExpr e,
693 AffineExpr multiplicativeFactor,
694 MutableArrayRef<AffineExpr> strides,
695 AffineExpr &offset) {
696 auto bin = dyn_cast<AffineBinaryOpExpr>(Val&: e);
697 if (!bin) {
698 extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
699 return success();
700 }
701
702 if (bin.getKind() == AffineExprKind::CeilDiv ||
703 bin.getKind() == AffineExprKind::FloorDiv ||
704 bin.getKind() == AffineExprKind::Mod)
705 return failure();
706
707 if (bin.getKind() == AffineExprKind::Mul) {
708 auto dim = dyn_cast<AffineDimExpr>(Val: bin.getLHS());
709 if (dim) {
710 strides[dim.getPosition()] =
711 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
712 return success();
713 }
714 // LHS and RHS may both contain complex expressions of dims. Try one path
715 // and if it fails try the other. This is guaranteed to succeed because
716 // only one path may have a `dim`, otherwise this is not an AffineExpr in
717 // the first place.
718 if (bin.getLHS().isSymbolicOrConstant())
719 return extractStrides(e: bin.getRHS(), multiplicativeFactor: multiplicativeFactor * bin.getLHS(),
720 strides, offset);
721 return extractStrides(e: bin.getLHS(), multiplicativeFactor: multiplicativeFactor * bin.getRHS(),
722 strides, offset);
723 }
724
725 if (bin.getKind() == AffineExprKind::Add) {
726 auto res1 =
727 extractStrides(e: bin.getLHS(), multiplicativeFactor, strides, offset);
728 auto res2 =
729 extractStrides(e: bin.getRHS(), multiplicativeFactor, strides, offset);
730 return success(isSuccess: succeeded(result: res1) && succeeded(result: res2));
731 }
732
733 llvm_unreachable("unexpected binary operation");
734}
735
736/// A stride specification is a list of integer values that are either static
737/// or dynamic (encoded with ShapedType::kDynamic). Strides encode
738/// the distance in the number of elements between successive entries along a
739/// particular dimension.
740///
741/// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
742/// non-contiguous memory region of `42` by `16` `f32` elements in which the
743/// distance between two consecutive elements along the outer dimension is `1`
744/// and the distance between two consecutive elements along the inner dimension
745/// is `64`.
746///
747/// The convention is that the strides for dimensions d0, .. dn appear in
748/// order to make indexing intuitive into the result.
749static LogicalResult getStridesAndOffset(MemRefType t,
750 SmallVectorImpl<AffineExpr> &strides,
751 AffineExpr &offset) {
752 AffineMap m = t.getLayout().getAffineMap();
753
754 if (m.getNumResults() != 1 && !m.isIdentity())
755 return failure();
756
757 auto zero = getAffineConstantExpr(0, t.getContext());
758 auto one = getAffineConstantExpr(1, t.getContext());
759 offset = zero;
760 strides.assign(t.getRank(), zero);
761
762 // Canonical case for empty map.
763 if (m.isIdentity()) {
764 // 0-D corner case, offset is already 0.
765 if (t.getRank() == 0)
766 return success();
767 auto stridedExpr =
768 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
769 if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
770 return success();
771 assert(false && "unexpected failure: extract strides in canonical layout");
772 }
773
774 // Non-canonical case requires more work.
775 auto stridedExpr =
776 simplifyAffineExpr(expr: m.getResult(idx: 0), numDims: m.getNumDims(), numSymbols: m.getNumSymbols());
777 if (failed(extractStrides(stridedExpr, one, strides, offset))) {
778 offset = AffineExpr();
779 strides.clear();
780 return failure();
781 }
782
783 // Simplify results to allow folding to constants and simple checks.
784 unsigned numDims = m.getNumDims();
785 unsigned numSymbols = m.getNumSymbols();
786 offset = simplifyAffineExpr(expr: offset, numDims, numSymbols);
787 for (auto &stride : strides)
788 stride = simplifyAffineExpr(expr: stride, numDims, numSymbols);
789
790 // In practice, a strided memref must be internally non-aliasing. Test
791 // against 0 as a proxy.
792 // TODO: static cases can have more advanced checks.
793 // TODO: dynamic cases would require a way to compare symbolic
794 // expressions and would probably need an affine set context propagated
795 // everywhere.
796 if (llvm::any_of(Range&: strides, P: [](AffineExpr e) {
797 return e == getAffineConstantExpr(constant: 0, context: e.getContext());
798 })) {
799 offset = AffineExpr();
800 strides.clear();
801 return failure();
802 }
803
804 return success();
805}
806
807LogicalResult mlir::getStridesAndOffset(MemRefType t,
808 SmallVectorImpl<int64_t> &strides,
809 int64_t &offset) {
810 // Happy path: the type uses the strided layout directly.
811 if (auto strided = llvm::dyn_cast<StridedLayoutAttr>(t.getLayout())) {
812 llvm::append_range(strides, strided.getStrides());
813 offset = strided.getOffset();
814 return success();
815 }
816
817 // Otherwise, defer to the affine fallback as layouts are supposed to be
818 // convertible to affine maps.
819 AffineExpr offsetExpr;
820 SmallVector<AffineExpr, 4> strideExprs;
821 if (failed(::getStridesAndOffset(t: t, strides&: strideExprs, offset&: offsetExpr)))
822 return failure();
823 if (auto cst = dyn_cast<AffineConstantExpr>(Val&: offsetExpr))
824 offset = cst.getValue();
825 else
826 offset = ShapedType::kDynamic;
827 for (auto e : strideExprs) {
828 if (auto c = dyn_cast<AffineConstantExpr>(Val&: e))
829 strides.push_back(Elt: c.getValue());
830 else
831 strides.push_back(ShapedType::kDynamic);
832 }
833 return success();
834}
835
836std::pair<SmallVector<int64_t>, int64_t>
837mlir::getStridesAndOffset(MemRefType t) {
838 SmallVector<int64_t> strides;
839 int64_t offset;
840 LogicalResult status = getStridesAndOffset(t, strides, offset);
841 (void)status;
842 assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset");
843 return {strides, offset};
844}
845
846//===----------------------------------------------------------------------===//
847/// TupleType
848//===----------------------------------------------------------------------===//
849
850/// Return the elements types for this tuple.
851ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
852
853/// Accumulate the types contained in this tuple and tuples nested within it.
854/// Note that this only flattens nested tuples, not any other container type,
855/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
856/// (i32, tensor<i32>, f32, i64)
857void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
858 for (Type type : getTypes()) {
859 if (auto nestedTuple = llvm::dyn_cast<TupleType>(type))
860 nestedTuple.getFlattenedTypes(types);
861 else
862 types.push_back(type);
863 }
864}
865
866/// Return the number of element types.
867size_t TupleType::size() const { return getImpl()->size(); }
868
869//===----------------------------------------------------------------------===//
870// Type Utilities
871//===----------------------------------------------------------------------===//
872
873/// Return a version of `t` with identity layout if it can be determined
874/// statically that the layout is the canonical contiguous strided layout.
875/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
876/// `t` with simplified layout.
877/// If `t` has multiple layout maps or a multi-result layout, just return `t`.
878MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
879 AffineMap m = t.getLayout().getAffineMap();
880
881 // Already in canonical form.
882 if (m.isIdentity())
883 return t;
884
885 // Can't reduce to canonical identity form, return in canonical form.
886 if (m.getNumResults() > 1)
887 return t;
888
889 // Corner-case for 0-D affine maps.
890 if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
891 if (auto cst = dyn_cast<AffineConstantExpr>(m.getResult(0)))
892 if (cst.getValue() == 0)
893 return MemRefType::Builder(t).setLayout({});
894 return t;
895 }
896
897 // 0-D corner case for empty shape that still have an affine map. Example:
898 // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
899 // offset needs to remain, just return t.
900 if (t.getShape().empty())
901 return t;
902
903 // If the canonical strided layout for the sizes of `t` is equal to the
904 // simplified layout of `t` we can just return an empty layout. Otherwise,
905 // just simplify the existing layout.
906 AffineExpr expr =
907 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
908 auto simplifiedLayoutExpr =
909 simplifyAffineExpr(expr: m.getResult(idx: 0), numDims: m.getNumDims(), numSymbols: m.getNumSymbols());
910 if (expr != simplifiedLayoutExpr)
911 return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get(
912 m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
913 return MemRefType::Builder(t).setLayout({});
914}
915
916AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
917 ArrayRef<AffineExpr> exprs,
918 MLIRContext *context) {
919 // Size 0 corner case is useful for canonicalizations.
920 if (sizes.empty())
921 return getAffineConstantExpr(constant: 0, context);
922
923 assert(!exprs.empty() && "expected exprs");
924 auto maps = AffineMap::inferFromExprList(exprsList: exprs, context);
925 assert(!maps.empty() && "Expected one non-empty map");
926 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
927
928 AffineExpr expr;
929 bool dynamicPoisonBit = false;
930 int64_t runningSize = 1;
931 for (auto en : llvm::zip(t: llvm::reverse(C&: exprs), u: llvm::reverse(C&: sizes))) {
932 int64_t size = std::get<1>(t&: en);
933 AffineExpr dimExpr = std::get<0>(t&: en);
934 AffineExpr stride = dynamicPoisonBit
935 ? getAffineSymbolExpr(position: nSymbols++, context)
936 : getAffineConstantExpr(constant: runningSize, context);
937 expr = expr ? expr + dimExpr * stride : dimExpr * stride;
938 if (size > 0) {
939 runningSize *= size;
940 assert(runningSize > 0 && "integer overflow in size computation");
941 } else {
942 dynamicPoisonBit = true;
943 }
944 }
945 return simplifyAffineExpr(expr, numDims, numSymbols: nSymbols);
946}
947
948AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
949 MLIRContext *context) {
950 SmallVector<AffineExpr, 4> exprs;
951 exprs.reserve(N: sizes.size());
952 for (auto dim : llvm::seq<unsigned>(Begin: 0, End: sizes.size()))
953 exprs.push_back(Elt: getAffineDimExpr(position: dim, context));
954 return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
955}
956
957bool mlir::isStrided(MemRefType t) {
958 int64_t offset;
959 SmallVector<int64_t, 4> strides;
960 auto res = getStridesAndOffset(t, strides, offset);
961 return succeeded(res);
962}
963
964bool mlir::isLastMemrefDimUnitStride(MemRefType type) {
965 int64_t offset;
966 SmallVector<int64_t> strides;
967 auto successStrides = getStridesAndOffset(type, strides, offset);
968 return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
969}
970
971bool mlir::trailingNDimsContiguous(MemRefType type, int64_t n) {
972 if (!isLastMemrefDimUnitStride(type))
973 return false;
974
975 auto memrefShape = type.getShape().take_back(n);
976 if (ShapedType::isDynamicShape(memrefShape))
977 return false;
978
979 if (type.getLayout().isIdentity())
980 return true;
981
982 int64_t offset;
983 SmallVector<int64_t> stridesFull;
984 if (!succeeded(getStridesAndOffset(type, stridesFull, offset)))
985 return false;
986 auto strides = ArrayRef<int64_t>(stridesFull).take_back(N: n);
987
988 if (strides.empty())
989 return true;
990
991 // Check whether strides match "flattened" dims.
992 SmallVector<int64_t> flattenedDims;
993 auto dimProduct = 1;
994 for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
995 dimProduct *= dim;
996 flattenedDims.push_back(dimProduct);
997 }
998
999 strides = strides.drop_back(N: 1);
1000 return llvm::equal(LRange&: strides, RRange: llvm::reverse(C&: flattenedDims));
1001}
1002

source code of mlir/lib/IR/BuiltinTypes.cpp