1 | //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/TensorEncoding.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"
#include <cstdint>
#include <limits>
24 | |
25 | using namespace mlir; |
26 | using namespace mlir::detail; |
27 | |
28 | //===----------------------------------------------------------------------===// |
29 | /// Tablegen Type Definitions |
30 | //===----------------------------------------------------------------------===// |
31 | |
32 | #define GET_TYPEDEF_CLASSES |
33 | #include "mlir/IR/BuiltinTypes.cpp.inc" |
34 | |
35 | namespace mlir { |
36 | #include "mlir/IR/BuiltinTypeConstraints.cpp.inc" |
37 | } // namespace mlir |
38 | |
39 | //===----------------------------------------------------------------------===// |
40 | // BuiltinDialect |
41 | //===----------------------------------------------------------------------===// |
42 | |
/// Register all builtin types with the builtin dialect. The list of types is
/// not written by hand: it is produced by expanding the tablegen-generated
/// `BuiltinTypes.cpp.inc` with GET_TYPEDEF_LIST defined.
void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}
49 | |
50 | //===----------------------------------------------------------------------===// |
51 | /// ComplexType |
52 | //===----------------------------------------------------------------------===// |
53 | |
54 | /// Verify the construction of an integer type. |
55 | LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError, |
56 | Type elementType) { |
57 | if (!elementType.isIntOrFloat()) |
58 | return emitError() << "invalid element type for complex" ; |
59 | return success(); |
60 | } |
61 | |
62 | //===----------------------------------------------------------------------===// |
63 | // Integer Type |
64 | //===----------------------------------------------------------------------===// |
65 | |
66 | /// Verify the construction of an integer type. |
67 | LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError, |
68 | unsigned width, |
69 | SignednessSemantics signedness) { |
70 | if (width > IntegerType::kMaxWidth) { |
71 | return emitError() << "integer bitwidth is limited to " |
72 | << IntegerType::kMaxWidth << " bits" ; |
73 | } |
74 | return success(); |
75 | } |
76 | |
77 | unsigned IntegerType::getWidth() const { return getImpl()->width; } |
78 | |
79 | IntegerType::SignednessSemantics IntegerType::getSignedness() const { |
80 | return getImpl()->signedness; |
81 | } |
82 | |
83 | IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { |
84 | if (!scale) |
85 | return IntegerType(); |
86 | return IntegerType::get(getContext(), scale * getWidth(), getSignedness()); |
87 | } |
88 | |
89 | //===----------------------------------------------------------------------===// |
90 | // Float Types |
91 | //===----------------------------------------------------------------------===// |
92 | |
// Mapping from MLIR FloatType to APFloat semantics.
//
// Each FLOAT_TYPE_SEMANTICS(TYPE, SEM) line below defines
// `TYPE::getFloatSemantics()` so that it returns `APFloat::SEM()`. The helper
// macro is #undef'd immediately after the table so it does not leak into the
// rest of the file.
#define FLOAT_TYPE_SEMANTICS(TYPE, SEM)                                        \
  const llvm::fltSemantics &TYPE::getFloatSemantics() const {                  \
    return APFloat::SEM();                                                     \
  }
FLOAT_TYPE_SEMANTICS(Float4E2M1FNType, Float4E2M1FN)
FLOAT_TYPE_SEMANTICS(Float6E2M3FNType, Float6E2M3FN)
FLOAT_TYPE_SEMANTICS(Float6E3M2FNType, Float6E3M2FN)
FLOAT_TYPE_SEMANTICS(Float8E5M2Type, Float8E5M2)
FLOAT_TYPE_SEMANTICS(Float8E4M3Type, Float8E4M3)
FLOAT_TYPE_SEMANTICS(Float8E4M3FNType, Float8E4M3FN)
FLOAT_TYPE_SEMANTICS(Float8E5M2FNUZType, Float8E5M2FNUZ)
FLOAT_TYPE_SEMANTICS(Float8E4M3FNUZType, Float8E4M3FNUZ)
FLOAT_TYPE_SEMANTICS(Float8E4M3B11FNUZType, Float8E4M3B11FNUZ)
FLOAT_TYPE_SEMANTICS(Float8E3M4Type, Float8E3M4)
FLOAT_TYPE_SEMANTICS(Float8E8M0FNUType, Float8E8M0FNU)
FLOAT_TYPE_SEMANTICS(BFloat16Type, BFloat)
FLOAT_TYPE_SEMANTICS(Float16Type, IEEEhalf)
FLOAT_TYPE_SEMANTICS(FloatTF32Type, FloatTF32)
FLOAT_TYPE_SEMANTICS(Float32Type, IEEEsingle)
FLOAT_TYPE_SEMANTICS(Float64Type, IEEEdouble)
FLOAT_TYPE_SEMANTICS(Float80Type, x87DoubleExtended)
FLOAT_TYPE_SEMANTICS(Float128Type, IEEEquad)
#undef FLOAT_TYPE_SEMANTICS
117 | |
118 | FloatType Float16Type::scaleElementBitwidth(unsigned scale) const { |
119 | if (scale == 2) |
120 | return Float32Type::get(getContext()); |
121 | if (scale == 4) |
122 | return Float64Type::get(getContext()); |
123 | return FloatType(); |
124 | } |
125 | |
126 | FloatType BFloat16Type::scaleElementBitwidth(unsigned scale) const { |
127 | if (scale == 2) |
128 | return Float32Type::get(getContext()); |
129 | if (scale == 4) |
130 | return Float64Type::get(getContext()); |
131 | return FloatType(); |
132 | } |
133 | |
134 | FloatType Float32Type::scaleElementBitwidth(unsigned scale) const { |
135 | if (scale == 2) |
136 | return Float64Type::get(getContext()); |
137 | return FloatType(); |
138 | } |
139 | |
140 | //===----------------------------------------------------------------------===// |
141 | // FunctionType |
142 | //===----------------------------------------------------------------------===// |
143 | |
144 | unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; } |
145 | |
146 | ArrayRef<Type> FunctionType::getInputs() const { |
147 | return getImpl()->getInputs(); |
148 | } |
149 | |
150 | unsigned FunctionType::getNumResults() const { return getImpl()->numResults; } |
151 | |
152 | ArrayRef<Type> FunctionType::getResults() const { |
153 | return getImpl()->getResults(); |
154 | } |
155 | |
156 | FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const { |
157 | return get(getContext(), inputs, results); |
158 | } |
159 | |
160 | /// Returns a new function type with the specified arguments and results |
161 | /// inserted. |
162 | FunctionType FunctionType::getWithArgsAndResults( |
163 | ArrayRef<unsigned> argIndices, TypeRange argTypes, |
164 | ArrayRef<unsigned> resultIndices, TypeRange resultTypes) { |
165 | SmallVector<Type> argStorage, resultStorage; |
166 | TypeRange newArgTypes = |
167 | insertTypesInto(getInputs(), argIndices, argTypes, argStorage); |
168 | TypeRange newResultTypes = |
169 | insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage); |
170 | return clone(newArgTypes, newResultTypes); |
171 | } |
172 | |
173 | /// Returns a new function type without the specified arguments and results. |
174 | FunctionType |
175 | FunctionType::getWithoutArgsAndResults(const BitVector &argIndices, |
176 | const BitVector &resultIndices) { |
177 | SmallVector<Type> argStorage, resultStorage; |
178 | TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage); |
179 | TypeRange newResultTypes = |
180 | filterTypesOut(getResults(), resultIndices, resultStorage); |
181 | return clone(newArgTypes, newResultTypes); |
182 | } |
183 | |
184 | //===----------------------------------------------------------------------===// |
185 | // OpaqueType |
186 | //===----------------------------------------------------------------------===// |
187 | |
188 | /// Verify the construction of an opaque type. |
189 | LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError, |
190 | StringAttr dialect, StringRef typeData) { |
191 | if (!Dialect::isValidNamespace(dialect.strref())) |
192 | return emitError() << "invalid dialect namespace '" << dialect << "'" ; |
193 | |
194 | // Check that the dialect is actually registered. |
195 | MLIRContext *context = dialect.getContext(); |
196 | if (!context->allowsUnregisteredDialects() && |
197 | !context->getLoadedDialect(dialect.strref())) { |
198 | return emitError() |
199 | << "`!" << dialect << "<\"" << typeData << "\">" |
200 | << "` type created with unregistered dialect. If this is " |
201 | "intended, please call allowUnregisteredDialects() on the " |
202 | "MLIRContext, or use -allow-unregistered-dialect with " |
203 | "the MLIR opt tool used" ; |
204 | } |
205 | |
206 | return success(); |
207 | } |
208 | |
209 | //===----------------------------------------------------------------------===// |
210 | // VectorType |
211 | //===----------------------------------------------------------------------===// |
212 | |
213 | bool VectorType::isValidElementType(Type t) { |
214 | return isValidVectorTypeElementType(t); |
215 | } |
216 | |
217 | LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError, |
218 | ArrayRef<int64_t> shape, Type elementType, |
219 | ArrayRef<bool> scalableDims) { |
220 | if (!isValidElementType(elementType)) |
221 | return emitError() |
222 | << "vector elements must be int/index/float type but got " |
223 | << elementType; |
224 | |
225 | if (any_of(shape, [](int64_t i) { return i <= 0; })) |
226 | return emitError() |
227 | << "vector types must have positive constant sizes but got " |
228 | << shape; |
229 | |
230 | if (scalableDims.size() != shape.size()) |
231 | return emitError() << "number of dims must match, got " |
232 | << scalableDims.size() << " and " << shape.size(); |
233 | |
234 | return success(); |
235 | } |
236 | |
237 | VectorType VectorType::scaleElementBitwidth(unsigned scale) { |
238 | if (!scale) |
239 | return VectorType(); |
240 | if (auto et = llvm::dyn_cast<IntegerType>(getElementType())) |
241 | if (auto scaledEt = et.scaleElementBitwidth(scale)) |
242 | return VectorType::get(getShape(), scaledEt, getScalableDims()); |
243 | if (auto et = llvm::dyn_cast<FloatType>(getElementType())) |
244 | if (auto scaledEt = et.scaleElementBitwidth(scale)) |
245 | return VectorType::get(getShape(), scaledEt, getScalableDims()); |
246 | return VectorType(); |
247 | } |
248 | |
249 | VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape, |
250 | Type elementType) const { |
251 | return VectorType::get(shape.value_or(getShape()), elementType, |
252 | getScalableDims()); |
253 | } |
254 | |
255 | //===----------------------------------------------------------------------===// |
256 | // TensorType |
257 | //===----------------------------------------------------------------------===// |
258 | |
259 | Type TensorType::getElementType() const { |
260 | return llvm::TypeSwitch<TensorType, Type>(*this) |
261 | .Case<RankedTensorType, UnrankedTensorType>( |
262 | [](auto type) { return type.getElementType(); }); |
263 | } |
264 | |
265 | bool TensorType::hasRank() const { |
266 | return !llvm::isa<UnrankedTensorType>(Val: *this); |
267 | } |
268 | |
269 | ArrayRef<int64_t> TensorType::getShape() const { |
270 | return llvm::cast<RankedTensorType>(*this).getShape(); |
271 | } |
272 | |
273 | TensorType TensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape, |
274 | Type elementType) const { |
275 | if (llvm::dyn_cast<UnrankedTensorType>(*this)) { |
276 | if (shape) |
277 | return RankedTensorType::get(*shape, elementType); |
278 | return UnrankedTensorType::get(elementType); |
279 | } |
280 | |
281 | auto rankedTy = llvm::cast<RankedTensorType>(*this); |
282 | if (!shape) |
283 | return RankedTensorType::get(rankedTy.getShape(), elementType, |
284 | rankedTy.getEncoding()); |
285 | return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType, |
286 | rankedTy.getEncoding()); |
287 | } |
288 | |
289 | RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape, |
290 | Type elementType) const { |
291 | return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType)); |
292 | } |
293 | |
294 | RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape) const { |
295 | return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType: getElementType())); |
296 | } |
297 | |
298 | // Check if "elementType" can be an element type of a tensor. |
299 | static LogicalResult |
300 | checkTensorElementType(function_ref<InFlightDiagnostic()> emitError, |
301 | Type elementType) { |
302 | if (!TensorType::isValidElementType(type: elementType)) |
303 | return emitError() << "invalid tensor element type: " << elementType; |
304 | return success(); |
305 | } |
306 | |
307 | /// Return true if the specified element type is ok in a tensor. |
308 | bool TensorType::isValidElementType(Type type) { |
309 | // Note: Non standard/builtin types are allowed to exist within tensor |
310 | // types. Dialects are expected to verify that tensor types have a valid |
311 | // element type within that dialect. |
312 | return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType, |
313 | IndexType>(type) || |
314 | !llvm::isa<BuiltinDialect>(type.getDialect()); |
315 | } |
316 | |
317 | //===----------------------------------------------------------------------===// |
318 | // RankedTensorType |
319 | //===----------------------------------------------------------------------===// |
320 | |
321 | LogicalResult |
322 | RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, |
323 | ArrayRef<int64_t> shape, Type elementType, |
324 | Attribute encoding) { |
325 | for (int64_t s : shape) |
326 | if (s < 0 && !ShapedType::isDynamic(s)) |
327 | return emitError() << "invalid tensor dimension size" ; |
328 | if (auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding)) |
329 | if (failed(v.verifyEncoding(shape, elementType, emitError))) |
330 | return failure(); |
331 | return checkTensorElementType(emitError, elementType); |
332 | } |
333 | |
334 | //===----------------------------------------------------------------------===// |
335 | // UnrankedTensorType |
336 | //===----------------------------------------------------------------------===// |
337 | |
338 | LogicalResult |
339 | UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, |
340 | Type elementType) { |
341 | return checkTensorElementType(emitError, elementType); |
342 | } |
343 | |
344 | //===----------------------------------------------------------------------===// |
345 | // BaseMemRefType |
346 | //===----------------------------------------------------------------------===// |
347 | |
348 | Type BaseMemRefType::getElementType() const { |
349 | return llvm::TypeSwitch<BaseMemRefType, Type>(*this) |
350 | .Case<MemRefType, UnrankedMemRefType>( |
351 | [](auto type) { return type.getElementType(); }); |
352 | } |
353 | |
354 | bool BaseMemRefType::hasRank() const { |
355 | return !llvm::isa<UnrankedMemRefType>(*this); |
356 | } |
357 | |
358 | ArrayRef<int64_t> BaseMemRefType::getShape() const { |
359 | return llvm::cast<MemRefType>(*this).getShape(); |
360 | } |
361 | |
362 | BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape, |
363 | Type elementType) const { |
364 | if (llvm::dyn_cast<UnrankedMemRefType>(*this)) { |
365 | if (!shape) |
366 | return UnrankedMemRefType::get(elementType, getMemorySpace()); |
367 | MemRefType::Builder builder(*shape, elementType); |
368 | builder.setMemorySpace(getMemorySpace()); |
369 | return builder; |
370 | } |
371 | |
372 | MemRefType::Builder builder(llvm::cast<MemRefType>(*this)); |
373 | if (shape) |
374 | builder.setShape(*shape); |
375 | builder.setElementType(elementType); |
376 | return builder; |
377 | } |
378 | |
379 | MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape, |
380 | Type elementType) const { |
381 | return ::llvm::cast<MemRefType>(cloneWith(shape, elementType)); |
382 | } |
383 | |
384 | MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const { |
385 | return ::llvm::cast<MemRefType>(cloneWith(shape, elementType: getElementType())); |
386 | } |
387 | |
388 | Attribute BaseMemRefType::getMemorySpace() const { |
389 | if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this)) |
390 | return rankedMemRefTy.getMemorySpace(); |
391 | return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace(); |
392 | } |
393 | |
394 | unsigned BaseMemRefType::getMemorySpaceAsInt() const { |
395 | if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this)) |
396 | return rankedMemRefTy.getMemorySpaceAsInt(); |
397 | return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt(); |
398 | } |
399 | |
400 | //===----------------------------------------------------------------------===// |
401 | // MemRefType |
402 | //===----------------------------------------------------------------------===// |
403 | |
404 | std::optional<llvm::SmallDenseSet<unsigned>> |
405 | mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape, |
406 | ArrayRef<int64_t> reducedShape, |
407 | bool matchDynamic) { |
408 | size_t originalRank = originalShape.size(), reducedRank = reducedShape.size(); |
409 | llvm::SmallDenseSet<unsigned> unusedDims; |
410 | unsigned reducedIdx = 0; |
411 | for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) { |
412 | // Greedily insert `originalIdx` if match. |
413 | int64_t origSize = originalShape[originalIdx]; |
414 | // if `matchDynamic`, count dynamic dims as a match, unless `origSize` is 1. |
415 | if (matchDynamic && reducedIdx < reducedRank && origSize != 1 && |
416 | (ShapedType::isDynamic(reducedShape[reducedIdx]) || |
417 | ShapedType::isDynamic(origSize))) { |
418 | reducedIdx++; |
419 | continue; |
420 | } |
421 | if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) { |
422 | reducedIdx++; |
423 | continue; |
424 | } |
425 | |
426 | unusedDims.insert(V: originalIdx); |
427 | // If no match on `originalIdx`, the `originalShape` at this dimension |
428 | // must be 1, otherwise we bail. |
429 | if (origSize != 1) |
430 | return std::nullopt; |
431 | } |
432 | // The whole reducedShape must be scanned, otherwise we bail. |
433 | if (reducedIdx != reducedRank) |
434 | return std::nullopt; |
435 | return unusedDims; |
436 | } |
437 | |
438 | SliceVerificationResult |
439 | mlir::isRankReducedType(ShapedType originalType, |
440 | ShapedType candidateReducedType) { |
441 | if (originalType == candidateReducedType) |
442 | return SliceVerificationResult::Success; |
443 | |
444 | ShapedType originalShapedType = llvm::cast<ShapedType>(originalType); |
445 | ShapedType candidateReducedShapedType = |
446 | llvm::cast<ShapedType>(candidateReducedType); |
447 | |
448 | // Rank and size logic is valid for all ShapedTypes. |
449 | ArrayRef<int64_t> originalShape = originalShapedType.getShape(); |
450 | ArrayRef<int64_t> candidateReducedShape = |
451 | candidateReducedShapedType.getShape(); |
452 | unsigned originalRank = originalShape.size(), |
453 | candidateReducedRank = candidateReducedShape.size(); |
454 | if (candidateReducedRank > originalRank) |
455 | return SliceVerificationResult::RankTooLarge; |
456 | |
457 | auto optionalUnusedDimsMask = |
458 | computeRankReductionMask(originalShape, candidateReducedShape); |
459 | |
460 | // Sizes cannot be matched in case empty vector is returned. |
461 | if (!optionalUnusedDimsMask) |
462 | return SliceVerificationResult::SizeMismatch; |
463 | |
464 | if (originalShapedType.getElementType() != |
465 | candidateReducedShapedType.getElementType()) |
466 | return SliceVerificationResult::ElemTypeMismatch; |
467 | |
468 | return SliceVerificationResult::Success; |
469 | } |
470 | |
471 | bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) { |
472 | // Empty attribute is allowed as default memory space. |
473 | if (!memorySpace) |
474 | return true; |
475 | |
476 | // Supported built-in attributes. |
477 | if (llvm::isa<IntegerAttr, StringAttr, DictionaryAttr>(memorySpace)) |
478 | return true; |
479 | |
480 | // Allow custom dialect attributes. |
481 | if (!isa<BuiltinDialect>(Val: memorySpace.getDialect())) |
482 | return true; |
483 | |
484 | return false; |
485 | } |
486 | |
487 | Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace, |
488 | MLIRContext *ctx) { |
489 | if (memorySpace == 0) |
490 | return nullptr; |
491 | |
492 | return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace); |
493 | } |
494 | |
495 | Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) { |
496 | IntegerAttr intMemorySpace = llvm::dyn_cast_or_null<IntegerAttr>(memorySpace); |
497 | if (intMemorySpace && intMemorySpace.getValue() == 0) |
498 | return nullptr; |
499 | |
500 | return memorySpace; |
501 | } |
502 | |
503 | unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) { |
504 | if (!memorySpace) |
505 | return 0; |
506 | |
507 | assert(llvm::isa<IntegerAttr>(memorySpace) && |
508 | "Using `getMemorySpaceInteger` with non-Integer attribute" ); |
509 | |
510 | return static_cast<unsigned>(llvm::cast<IntegerAttr>(memorySpace).getInt()); |
511 | } |
512 | |
513 | unsigned MemRefType::getMemorySpaceAsInt() const { |
514 | return detail::getMemorySpaceAsInt(getMemorySpace()); |
515 | } |
516 | |
517 | MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, |
518 | MemRefLayoutAttrInterface layout, |
519 | Attribute memorySpace) { |
520 | // Use default layout for empty attribute. |
521 | if (!layout) |
522 | layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap( |
523 | shape.size(), elementType.getContext())); |
524 | |
525 | // Drop default memory space value and replace it with empty attribute. |
526 | memorySpace = skipDefaultMemorySpace(memorySpace); |
527 | |
528 | return Base::get(elementType.getContext(), shape, elementType, layout, |
529 | memorySpace); |
530 | } |
531 | |
532 | MemRefType MemRefType::getChecked( |
533 | function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape, |
534 | Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) { |
535 | |
536 | // Use default layout for empty attribute. |
537 | if (!layout) |
538 | layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap( |
539 | shape.size(), elementType.getContext())); |
540 | |
541 | // Drop default memory space value and replace it with empty attribute. |
542 | memorySpace = skipDefaultMemorySpace(memorySpace); |
543 | |
544 | return Base::getChecked(emitErrorFn, elementType.getContext(), shape, |
545 | elementType, layout, memorySpace); |
546 | } |
547 | |
548 | MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, |
549 | AffineMap map, Attribute memorySpace) { |
550 | |
551 | // Use default layout for empty map. |
552 | if (!map) |
553 | map = AffineMap::getMultiDimIdentityMap(shape.size(), |
554 | elementType.getContext()); |
555 | |
556 | // Wrap AffineMap into Attribute. |
557 | auto layout = AffineMapAttr::get(map); |
558 | |
559 | // Drop default memory space value and replace it with empty attribute. |
560 | memorySpace = skipDefaultMemorySpace(memorySpace); |
561 | |
562 | return Base::get(elementType.getContext(), shape, elementType, layout, |
563 | memorySpace); |
564 | } |
565 | |
566 | MemRefType |
567 | MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn, |
568 | ArrayRef<int64_t> shape, Type elementType, AffineMap map, |
569 | Attribute memorySpace) { |
570 | |
571 | // Use default layout for empty map. |
572 | if (!map) |
573 | map = AffineMap::getMultiDimIdentityMap(shape.size(), |
574 | elementType.getContext()); |
575 | |
576 | // Wrap AffineMap into Attribute. |
577 | auto layout = AffineMapAttr::get(map); |
578 | |
579 | // Drop default memory space value and replace it with empty attribute. |
580 | memorySpace = skipDefaultMemorySpace(memorySpace); |
581 | |
582 | return Base::getChecked(emitErrorFn, elementType.getContext(), shape, |
583 | elementType, layout, memorySpace); |
584 | } |
585 | |
586 | MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, |
587 | AffineMap map, unsigned memorySpaceInd) { |
588 | |
589 | // Use default layout for empty map. |
590 | if (!map) |
591 | map = AffineMap::getMultiDimIdentityMap(shape.size(), |
592 | elementType.getContext()); |
593 | |
594 | // Wrap AffineMap into Attribute. |
595 | auto layout = AffineMapAttr::get(map); |
596 | |
597 | // Convert deprecated integer-like memory space to Attribute. |
598 | Attribute memorySpace = |
599 | wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext()); |
600 | |
601 | return Base::get(elementType.getContext(), shape, elementType, layout, |
602 | memorySpace); |
603 | } |
604 | |
605 | MemRefType |
606 | MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn, |
607 | ArrayRef<int64_t> shape, Type elementType, AffineMap map, |
608 | unsigned memorySpaceInd) { |
609 | |
610 | // Use default layout for empty map. |
611 | if (!map) |
612 | map = AffineMap::getMultiDimIdentityMap(shape.size(), |
613 | elementType.getContext()); |
614 | |
615 | // Wrap AffineMap into Attribute. |
616 | auto layout = AffineMapAttr::get(map); |
617 | |
618 | // Convert deprecated integer-like memory space to Attribute. |
619 | Attribute memorySpace = |
620 | wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext()); |
621 | |
622 | return Base::getChecked(emitErrorFn, elementType.getContext(), shape, |
623 | elementType, layout, memorySpace); |
624 | } |
625 | |
626 | LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError, |
627 | ArrayRef<int64_t> shape, Type elementType, |
628 | MemRefLayoutAttrInterface layout, |
629 | Attribute memorySpace) { |
630 | if (!BaseMemRefType::isValidElementType(elementType)) |
631 | return emitError() << "invalid memref element type" ; |
632 | |
633 | // Negative sizes are not allowed except for `kDynamic`. |
634 | for (int64_t s : shape) |
635 | if (s < 0 && !ShapedType::isDynamic(s)) |
636 | return emitError() << "invalid memref size" ; |
637 | |
638 | assert(layout && "missing layout specification" ); |
639 | if (failed(layout.verifyLayout(shape, emitError))) |
640 | return failure(); |
641 | |
642 | if (!isSupportedMemorySpace(memorySpace)) |
643 | return emitError() << "unsupported memory space Attribute" ; |
644 | |
645 | return success(); |
646 | } |
647 | |
648 | bool MemRefType::areTrailingDimsContiguous(int64_t n) { |
649 | if (!isLastDimUnitStride()) |
650 | return false; |
651 | |
652 | auto memrefShape = getShape().take_back(n); |
653 | if (ShapedType::isDynamicShape(memrefShape)) |
654 | return false; |
655 | |
656 | if (getLayout().isIdentity()) |
657 | return true; |
658 | |
659 | int64_t offset; |
660 | SmallVector<int64_t> stridesFull; |
661 | if (!succeeded(getStridesAndOffset(stridesFull, offset))) |
662 | return false; |
663 | auto strides = ArrayRef<int64_t>(stridesFull).take_back(n); |
664 | |
665 | if (strides.empty()) |
666 | return true; |
667 | |
668 | // Check whether strides match "flattened" dims. |
669 | SmallVector<int64_t> flattenedDims; |
670 | auto dimProduct = 1; |
671 | for (auto dim : llvm::reverse(memrefShape.drop_front(1))) { |
672 | dimProduct *= dim; |
673 | flattenedDims.push_back(dimProduct); |
674 | } |
675 | |
676 | strides = strides.drop_back(1); |
677 | return llvm::equal(strides, llvm::reverse(flattenedDims)); |
678 | } |
679 | |
/// Return an equivalent memref type whose strided layout is in canonical form.
///
///  - An identity layout, or a layout map with more than one result, is
///    returned unchanged.
///  - A map with no dims and no symbols is reset via `setLayout({})` when its
///    single result is the constant 0. Otherwise it is returned unchanged.
///  - A 0-d shape that still carries a map (with dims or symbols) is returned
///    unchanged.
///  - Otherwise the single result expression is simplified. If it equals the
///    canonical strided expression for the shape, the layout is reset via
///    `setLayout({})`. If not, the simplified map becomes the new layout.
MemRefType MemRefType::canonicalizeStridedLayout() {
  AffineMap m = getLayout().getAffineMap();

  // Already in canonical form.
  if (m.isIdentity())
    return *this;

  // Can't reduce to canonical identity form, return unchanged.
  if (m.getNumResults() > 1)
    return *this;

  // Corner-case for 0-D affine maps (no dims, no symbols): only a constant-0
  // result can be dropped.
  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
    if (auto cst = llvm::dyn_cast<AffineConstantExpr>(m.getResult(0)))
      if (cst.getValue() == 0)
        return MemRefType::Builder(*this).setLayout({});
    return *this;
  }

  // 0-D corner case for empty shape that still have an affine map. Example:
  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
  // offset needs to remain, so return this type unchanged.
  if (getShape().empty())
    return *this;

  // If the canonical strided layout for the sizes of this type is equal to the
  // simplified layout of this type we can just return an empty layout.
  // Otherwise, just simplify the existing layout.
  AffineExpr expr = makeCanonicalStridedLayoutExpr(getShape(), getContext());
  auto simplifiedLayoutExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (expr != simplifiedLayoutExpr)
    return MemRefType::Builder(*this).setLayout(
        AffineMapAttr::get(AffineMap::get(m.getNumDims(), m.getNumSymbols(),
                                          simplifiedLayoutExpr)));
  return MemRefType::Builder(*this).setLayout({});
}
717 | |
718 | LogicalResult MemRefType::getStridesAndOffset(SmallVectorImpl<int64_t> &strides, |
719 | int64_t &offset) { |
720 | return getLayout().getStridesAndOffset(getShape(), strides, offset); |
721 | } |
722 | |
723 | std::pair<SmallVector<int64_t>, int64_t> MemRefType::getStridesAndOffset() { |
724 | SmallVector<int64_t> strides; |
725 | int64_t offset; |
726 | LogicalResult status = getStridesAndOffset(strides, offset); |
727 | (void)status; |
728 | assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset" ); |
729 | return {strides, offset}; |
730 | } |
731 | |
732 | bool MemRefType::isStrided() { |
733 | int64_t offset; |
734 | SmallVector<int64_t, 4> strides; |
735 | auto res = getStridesAndOffset(strides, offset); |
736 | return succeeded(res); |
737 | } |
738 | |
739 | bool MemRefType::isLastDimUnitStride() { |
740 | int64_t offset; |
741 | SmallVector<int64_t> strides; |
742 | auto successStrides = getStridesAndOffset(strides, offset); |
743 | return succeeded(successStrides) && (strides.empty() || strides.back() == 1); |
744 | } |
745 | |
746 | //===----------------------------------------------------------------------===// |
747 | // UnrankedMemRefType |
748 | //===----------------------------------------------------------------------===// |
749 | |
750 | unsigned UnrankedMemRefType::getMemorySpaceAsInt() const { |
751 | return detail::getMemorySpaceAsInt(getMemorySpace()); |
752 | } |
753 | |
754 | LogicalResult |
755 | UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError, |
756 | Type elementType, Attribute memorySpace) { |
757 | if (!BaseMemRefType::isValidElementType(elementType)) |
758 | return emitError() << "invalid memref element type" ; |
759 | |
760 | if (!isSupportedMemorySpace(memorySpace)) |
761 | return emitError() << "unsupported memory space Attribute" ; |
762 | |
763 | return success(); |
764 | } |
765 | |
766 | //===----------------------------------------------------------------------===// |
767 | /// TupleType |
768 | //===----------------------------------------------------------------------===// |
769 | |
770 | /// Return the elements types for this tuple. |
771 | ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); } |
772 | |
773 | /// Accumulate the types contained in this tuple and tuples nested within it. |
774 | /// Note that this only flattens nested tuples, not any other container type, |
775 | /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to |
776 | /// (i32, tensor<i32>, f32, i64) |
777 | void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) { |
778 | for (Type type : getTypes()) { |
779 | if (auto nestedTuple = llvm::dyn_cast<TupleType>(type)) |
780 | nestedTuple.getFlattenedTypes(types); |
781 | else |
782 | types.push_back(type); |
783 | } |
784 | } |
785 | |
786 | /// Return the number of element types. |
787 | size_t TupleType::size() const { return getImpl()->size(); } |
788 | |
789 | //===----------------------------------------------------------------------===// |
790 | // Type Utilities |
791 | //===----------------------------------------------------------------------===// |
792 | |
793 | AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, |
794 | ArrayRef<AffineExpr> exprs, |
795 | MLIRContext *context) { |
796 | // Size 0 corner case is useful for canonicalizations. |
797 | if (sizes.empty()) |
798 | return getAffineConstantExpr(constant: 0, context); |
799 | |
800 | assert(!exprs.empty() && "expected exprs" ); |
801 | auto maps = AffineMap::inferFromExprList(exprsList: exprs, context); |
802 | assert(!maps.empty() && "Expected one non-empty map" ); |
803 | unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols(); |
804 | |
805 | AffineExpr expr; |
806 | bool dynamicPoisonBit = false; |
807 | int64_t runningSize = 1; |
808 | for (auto en : llvm::zip(t: llvm::reverse(C&: exprs), u: llvm::reverse(C&: sizes))) { |
809 | int64_t size = std::get<1>(t&: en); |
810 | AffineExpr dimExpr = std::get<0>(t&: en); |
811 | AffineExpr stride = dynamicPoisonBit |
812 | ? getAffineSymbolExpr(position: nSymbols++, context) |
813 | : getAffineConstantExpr(constant: runningSize, context); |
814 | expr = expr ? expr + dimExpr * stride : dimExpr * stride; |
815 | if (size > 0) { |
816 | runningSize *= size; |
817 | assert(runningSize > 0 && "integer overflow in size computation" ); |
818 | } else { |
819 | dynamicPoisonBit = true; |
820 | } |
821 | } |
822 | return simplifyAffineExpr(expr, numDims, numSymbols: nSymbols); |
823 | } |
824 | |
825 | AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, |
826 | MLIRContext *context) { |
827 | SmallVector<AffineExpr, 4> exprs; |
828 | exprs.reserve(N: sizes.size()); |
829 | for (auto dim : llvm::seq<unsigned>(Begin: 0, End: sizes.size())) |
830 | exprs.push_back(Elt: getAffineDimExpr(position: dim, context)); |
831 | return makeCanonicalStridedLayoutExpr(sizes, exprs, context); |
832 | } |
833 | |