//===- BuiltinTypes.cpp - MLIR Builtin Type Classes ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/TensorEncoding.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
// Tablegen Type Definitions
//===----------------------------------------------------------------------===//

#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"

//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//

void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
// ComplexType
//===----------------------------------------------------------------------===//
/// Verify the construction of a complex type.
LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  Type elementType) {
  if (!elementType.isIntOrFloat())
    return emitError() << "invalid element type for complex";
  return success();
}

//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//

/// Verify the construction of an integer type.
LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  unsigned width,
                                  SignednessSemantics signedness) {
  if (width > IntegerType::kMaxWidth) {
    return emitError() << "integer bitwidth is limited to "
                       << IntegerType::kMaxWidth << " bits";
  }
  return success();
}

unsigned IntegerType::getWidth() const { return getImpl()->width; }

IntegerType::SignednessSemantics IntegerType::getSignedness() const {
  return getImpl()->signedness;
}

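/// Return an integer type whose bitwidth is `scale` times this type's
/// bitwidth, preserving signedness, or a null type when `scale` is zero.
/// For example (illustrative): scaling `i8` by 2 yields `i16`.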
IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return IntegerType();
  return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
}

//===----------------------------------------------------------------------===//
// Float Type
//===----------------------------------------------------------------------===//

unsigned FloatType::getWidth() {
  if (llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
                Float8E4M3FNUZType, Float8E4M3B11FNUZType>(*this))
    return 8;
  if (llvm::isa<Float16Type, BFloat16Type>(*this))
    return 16;
  if (llvm::isa<Float32Type, FloatTF32Type>(*this))
    return 32;
  if (llvm::isa<Float64Type>(*this))
    return 64;
  if (llvm::isa<Float80Type>(*this))
    return 80;
  if (llvm::isa<Float128Type>(*this))
    return 128;
  llvm_unreachable("unexpected float type");
}

/// Returns the floating semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
  if (llvm::isa<Float8E5M2Type>(*this))
    return APFloat::Float8E5M2();
  if (llvm::isa<Float8E4M3FNType>(*this))
    return APFloat::Float8E4M3FN();
  if (llvm::isa<Float8E5M2FNUZType>(*this))
    return APFloat::Float8E5M2FNUZ();
  if (llvm::isa<Float8E4M3FNUZType>(*this))
    return APFloat::Float8E4M3FNUZ();
  if (llvm::isa<Float8E4M3B11FNUZType>(*this))
    return APFloat::Float8E4M3B11FNUZ();
  if (llvm::isa<BFloat16Type>(*this))
    return APFloat::BFloat();
  if (llvm::isa<Float16Type>(*this))
    return APFloat::IEEEhalf();
  if (llvm::isa<FloatTF32Type>(*this))
    return APFloat::FloatTF32();
  if (llvm::isa<Float32Type>(*this))
    return APFloat::IEEEsingle();
  if (llvm::isa<Float64Type>(*this))
    return APFloat::IEEEdouble();
  if (llvm::isa<Float80Type>(*this))
    return APFloat::x87DoubleExtended();
  if (llvm::isa<Float128Type>(*this))
    return APFloat::IEEEquad();
  llvm_unreachable("non-floating point type used");
}

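/// Return a float type whose bitwidth is `scale` times this type's bitwidth,
/// or a null type when no such builtin type exists. For example
/// (illustrative): `f16` and `bf16` scale by 2 to `f32` and by 4 to `f64`,
/// and `f32` scales by 2 to `f64`; all other combinations yield a null type.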
FloatType FloatType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return FloatType();
  MLIRContext *ctx = getContext();
  if (isF16() || isBF16()) {
    if (scale == 2)
      return FloatType::getF32(ctx);
    if (scale == 4)
      return FloatType::getF64(ctx);
  }
  if (isF32())
    if (scale == 2)
      return FloatType::getF64(ctx);
  return FloatType();
}

unsigned FloatType::getFPMantissaWidth() {
  return APFloat::semanticsPrecision(getFloatSemantics());
}

//===----------------------------------------------------------------------===//
// FunctionType
//===----------------------------------------------------------------------===//

unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }

ArrayRef<Type> FunctionType::getInputs() const {
  return getImpl()->getInputs();
}

unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }

ArrayRef<Type> FunctionType::getResults() const {
  return getImpl()->getResults();
}

FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
  return get(getContext(), inputs, results);
}

/// Returns a new function type with the specified arguments and results
/// inserted.
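/// For example (illustrative): inserting `i32` at argument index 0 of
/// `(f32) -> (f32)` produces `(i32, f32) -> (f32)`.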
FunctionType FunctionType::getWithArgsAndResults(
    ArrayRef<unsigned> argIndices, TypeRange argTypes,
    ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
  SmallVector<Type> argStorage, resultStorage;
  TypeRange newArgTypes =
      insertTypesInto(getInputs(), argIndices, argTypes, argStorage);
  TypeRange newResultTypes =
      insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
  return clone(newArgTypes, newResultTypes);
}

/// Returns a new function type without the specified arguments and results.
FunctionType
FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
                                       const BitVector &resultIndices) {
  SmallVector<Type> argStorage, resultStorage;
  TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage);
  TypeRange newResultTypes =
      filterTypesOut(getResults(), resultIndices, resultStorage);
  return clone(newArgTypes, newResultTypes);
}

//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//

/// Verify the construction of an opaque type.
LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 StringAttr dialect, StringRef typeData) {
  if (!Dialect::isValidNamespace(dialect.strref()))
    return emitError() << "invalid dialect namespace '" << dialect << "'";

  // Check that the dialect is actually registered.
  MLIRContext *context = dialect.getContext();
  if (!context->allowsUnregisteredDialects() &&
      !context->getLoadedDialect(dialect.strref())) {
    return emitError()
           << "`!" << dialect << "<\"" << typeData << "\">"
           << "` type created with unregistered dialect. If this is "
              "intended, please call allowUnregisteredDialects() on the "
              "MLIRContext, or use -allow-unregistered-dialect with "
              "the MLIR opt tool used";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// VectorType
//===----------------------------------------------------------------------===//

LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
                                 ArrayRef<bool> scalableDims) {
  if (!isValidElementType(elementType))
    return emitError()
           << "vector elements must be int/index/float type but got "
           << elementType;

  if (any_of(shape, [](int64_t i) { return i <= 0; }))
    return emitError()
           << "vector types must have positive constant sizes but got "
           << shape;

  if (scalableDims.size() != shape.size())
    return emitError() << "number of dims must match, got "
                       << scalableDims.size() << " and " << shape.size();

  return success();
}

VectorType VectorType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return VectorType();
  if (auto et = llvm::dyn_cast<IntegerType>(getElementType()))
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt, getScalableDims());
  if (auto et = llvm::dyn_cast<FloatType>(getElementType()))
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt, getScalableDims());
  return VectorType();
}

VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
                                 Type elementType) const {
  return VectorType::get(shape.value_or(getShape()), elementType,
                         getScalableDims());
}

//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//

Type TensorType::getElementType() const {
  return llvm::TypeSwitch<TensorType, Type>(*this)
      .Case<RankedTensorType, UnrankedTensorType>(
          [](auto type) { return type.getElementType(); });
}

bool TensorType::hasRank() const {
  return !llvm::isa<UnrankedTensorType>(*this);
}

ArrayRef<int64_t> TensorType::getShape() const {
  return llvm::cast<RankedTensorType>(*this).getShape();
}

TensorType TensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
                                 Type elementType) const {
  if (llvm::dyn_cast<UnrankedTensorType>(*this)) {
    if (shape)
      return RankedTensorType::get(*shape, elementType);
    return UnrankedTensorType::get(elementType);
  }

  auto rankedTy = llvm::cast<RankedTensorType>(*this);
  if (!shape)
    return RankedTensorType::get(rankedTy.getShape(), elementType,
                                 rankedTy.getEncoding());
  return RankedTensorType::get(*shape, elementType, rankedTy.getEncoding());
}

RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape,
                                   Type elementType) const {
  return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
}

RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape) const {
  return ::llvm::cast<RankedTensorType>(cloneWith(shape, getElementType()));
}

// Check if "elementType" can be an element type of a tensor.
static LogicalResult
checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
                       Type elementType) {
  if (!TensorType::isValidElementType(elementType))
    return emitError() << "invalid tensor element type: " << elementType;
  return success();
}

/// Return true if the specified element type is ok in a tensor.
bool TensorType::isValidElementType(Type type) {
  // Note: Non standard/builtin types are allowed to exist within tensor
  // types. Dialects are expected to verify that tensor types have a valid
  // element type within that dialect.
  return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
                   IndexType>(type) ||
         !llvm::isa<BuiltinDialect>(type.getDialect());
}

//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//

LogicalResult
RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                         ArrayRef<int64_t> shape, Type elementType,
                         Attribute encoding) {
  for (int64_t s : shape)
    if (s < 0 && !ShapedType::isDynamic(s))
      return emitError() << "invalid tensor dimension size";
  if (auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
    if (failed(v.verifyEncoding(shape, elementType, emitError)))
      return failure();
  return checkTensorElementType(emitError, elementType);
}

//===----------------------------------------------------------------------===//
// UnrankedTensorType
//===----------------------------------------------------------------------===//

LogicalResult
UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType) {
  return checkTensorElementType(emitError, elementType);
}

//===----------------------------------------------------------------------===//
// BaseMemRefType
//===----------------------------------------------------------------------===//

Type BaseMemRefType::getElementType() const {
  return llvm::TypeSwitch<BaseMemRefType, Type>(*this)
      .Case<MemRefType, UnrankedMemRefType>(
          [](auto type) { return type.getElementType(); });
}

bool BaseMemRefType::hasRank() const {
  return !llvm::isa<UnrankedMemRefType>(*this);
}

ArrayRef<int64_t> BaseMemRefType::getShape() const {
  return llvm::cast<MemRefType>(*this).getShape();
}

BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
                                         Type elementType) const {
  if (llvm::dyn_cast<UnrankedMemRefType>(*this)) {
    if (!shape)
      return UnrankedMemRefType::get(elementType, getMemorySpace());
    MemRefType::Builder builder(*shape, elementType);
    builder.setMemorySpace(getMemorySpace());
    return builder;
  }

  MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
  if (shape)
    builder.setShape(*shape);
  builder.setElementType(elementType);
  return builder;
}

MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape,
                                 Type elementType) const {
  return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
}

MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const {
  return ::llvm::cast<MemRefType>(cloneWith(shape, getElementType()));
}

Attribute BaseMemRefType::getMemorySpace() const {
  if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
    return rankedMemRefTy.getMemorySpace();
  return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
}

unsigned BaseMemRefType::getMemorySpaceAsInt() const {
  if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
    return rankedMemRefTy.getMemorySpaceAsInt();
  return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
}

//===----------------------------------------------------------------------===//
// MemRefType
//===----------------------------------------------------------------------===//

/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
/// `originalShape` with some `1` entries erased, return the set of indices
/// that specifies which of the entries of `originalShape` are dropped to
/// obtain `reducedShape`. The returned mask can be applied as a projection to
/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
/// which dimensions must be kept when e.g. computing MemRef strides under
/// rank-reducing operations. Return std::nullopt if `reducedShape` cannot be
/// obtained by dropping only `1` entries in `originalShape`.
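///
/// For example (illustrative): originalShape = [1, 4, 1, 8] and
/// reducedShape = [4, 8] yield the mask {0, 2}, i.e. the two unit
/// dimensions at indices 0 and 2 are dropped.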
std::optional<llvm::SmallDenseSet<unsigned>>
mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
                               ArrayRef<int64_t> reducedShape) {
  size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
  llvm::SmallDenseSet<unsigned> unusedDims;
  unsigned reducedIdx = 0;
  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
    // Greedily insert `originalIdx` if match.
    if (reducedIdx < reducedRank &&
        originalShape[originalIdx] == reducedShape[reducedIdx]) {
      reducedIdx++;
      continue;
    }

    unusedDims.insert(originalIdx);
    // If no match on `originalIdx`, the `originalShape` at this dimension
    // must be 1, otherwise we bail.
    if (originalShape[originalIdx] != 1)
      return std::nullopt;
  }
  // The whole reducedShape must be scanned, otherwise we bail.
  if (reducedIdx != reducedRank)
    return std::nullopt;
  return unusedDims;
}

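/// Check if `candidateReducedType` is a rank-reduced version of
/// `originalType` with the same element type. For example (illustrative),
/// `tensor<4x8xf32>` is a valid rank reduction of `tensor<1x4x1x8xf32>`.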
SliceVerificationResult
mlir::isRankReducedType(ShapedType originalType,
                        ShapedType candidateReducedType) {
  if (originalType == candidateReducedType)
    return SliceVerificationResult::Success;

  ShapedType originalShapedType = llvm::cast<ShapedType>(originalType);
  ShapedType candidateReducedShapedType =
      llvm::cast<ShapedType>(candidateReducedType);

  // Rank and size logic is valid for all ShapedTypes.
  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
  ArrayRef<int64_t> candidateReducedShape =
      candidateReducedShapedType.getShape();
  unsigned originalRank = originalShape.size(),
           candidateReducedRank = candidateReducedShape.size();
  if (candidateReducedRank > originalRank)
    return SliceVerificationResult::RankTooLarge;

  auto optionalUnusedDimsMask =
      computeRankReductionMask(originalShape, candidateReducedShape);

  // Sizes cannot be matched in case empty vector is returned.
  if (!optionalUnusedDimsMask)
    return SliceVerificationResult::SizeMismatch;

  if (originalShapedType.getElementType() !=
      candidateReducedShapedType.getElementType())
    return SliceVerificationResult::ElemTypeMismatch;

  return SliceVerificationResult::Success;
}

bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
  // Empty attribute is allowed as default memory space.
  if (!memorySpace)
    return true;

  // Supported built-in attributes.
  if (llvm::isa<IntegerAttr, StringAttr, DictionaryAttr>(memorySpace))
    return true;

  // Allow custom dialect attributes.
  if (!isa<BuiltinDialect>(memorySpace.getDialect()))
    return true;

  return false;
}

Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
                                               MLIRContext *ctx) {
  if (memorySpace == 0)
    return nullptr;

  return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
}

Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
  IntegerAttr intMemorySpace = llvm::dyn_cast_or_null<IntegerAttr>(memorySpace);
  if (intMemorySpace && intMemorySpace.getValue() == 0)
    return nullptr;

  return memorySpace;
}

unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
  if (!memorySpace)
    return 0;

  assert(llvm::isa<IntegerAttr>(memorySpace) &&
         "Using `getMemorySpaceInteger` with non-Integer attribute");

  return static_cast<unsigned>(llvm::cast<IntegerAttr>(memorySpace).getInt());
}

unsigned MemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}

MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           MemRefLayoutAttrInterface layout,
                           Attribute memorySpace) {
  // Use default layout for empty attribute.
  if (!layout)
    layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
        shape.size(), elementType.getContext()));

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}

MemRefType MemRefType::getChecked(
    function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
    Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {

  // Use default layout for empty attribute.
  if (!layout)
    layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
        shape.size(), elementType.getContext()));

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}

MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           AffineMap map, Attribute memorySpace) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  auto layout = AffineMapAttr::get(map);

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}

MemRefType
MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                       ArrayRef<int64_t> shape, Type elementType, AffineMap map,
                       Attribute memorySpace) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  auto layout = AffineMapAttr::get(map);

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}

MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           AffineMap map, unsigned memorySpaceInd) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  auto layout = AffineMapAttr::get(map);

  // Convert deprecated integer-like memory space to Attribute.
  Attribute memorySpace =
      wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}

MemRefType
MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                       ArrayRef<int64_t> shape, Type elementType, AffineMap map,
                       unsigned memorySpaceInd) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  auto layout = AffineMapAttr::get(map);

  // Convert deprecated integer-like memory space to Attribute.
  Attribute memorySpace =
      wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}

LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
                                 MemRefLayoutAttrInterface layout,
                                 Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  // Negative sizes are not allowed except for `kDynamic`.
  for (int64_t s : shape)
    if (s < 0 && !ShapedType::isDynamic(s))
      return emitError() << "invalid memref size";

  assert(layout && "missing layout specification");
  if (failed(layout.verifyLayout(shape, emitError)))
    return failure();

  if (!isSupportedMemorySpace(memorySpace))
    return emitError() << "unsupported memory space Attribute";

  return success();
}

//===----------------------------------------------------------------------===//
// UnrankedMemRefType
//===----------------------------------------------------------------------===//

unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}

LogicalResult
UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType, Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  if (!isSupportedMemorySpace(memorySpace))
    return emitError() << "unsupported memory space Attribute";

  return success();
}

// Fallback case for a terminal dim/sym/cst expression that is not part of a
// binary op (i.e. a single term). Accumulate the AffineExpr into the existing
// strides/offset.
static void extractStridesFromTerm(AffineExpr e,
                                   AffineExpr multiplicativeFactor,
                                   MutableArrayRef<AffineExpr> strides,
                                   AffineExpr &offset) {
  if (auto dim = dyn_cast<AffineDimExpr>(e))
    strides[dim.getPosition()] =
        strides[dim.getPosition()] + multiplicativeFactor;
  else
    offset = offset + e * multiplicativeFactor;
}

/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// stride expressions for each dim position.
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order to make indexing intuitive into the result.
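///
/// For example (illustrative): `e = d0 * 42 + d1 + 7` yields strides
/// `[42, 1]` and offset `7`.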
static LogicalResult extractStrides(AffineExpr e,
                                    AffineExpr multiplicativeFactor,
                                    MutableArrayRef<AffineExpr> strides,
                                    AffineExpr &offset) {
  auto bin = dyn_cast<AffineBinaryOpExpr>(e);
  if (!bin) {
    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
    return success();
  }

  if (bin.getKind() == AffineExprKind::CeilDiv ||
      bin.getKind() == AffineExprKind::FloorDiv ||
      bin.getKind() == AffineExprKind::Mod)
    return failure();

  if (bin.getKind() == AffineExprKind::Mul) {
    auto dim = dyn_cast<AffineDimExpr>(bin.getLHS());
    if (dim) {
      strides[dim.getPosition()] =
          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
      return success();
    }
    // LHS and RHS may both contain complex expressions of dims. Try one path
    // and if it fails try the other. This is guaranteed to succeed because
    // only one path may have a `dim`, otherwise this is not an AffineExpr in
    // the first place.
    if (bin.getLHS().isSymbolicOrConstant())
      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
                            strides, offset);
    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
                          strides, offset);
  }

  if (bin.getKind() == AffineExprKind::Add) {
    auto res1 =
        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
    auto res2 =
        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
    return success(succeeded(res1) && succeeded(res2));
  }

  llvm_unreachable("unexpected binary operation");
}

/// A stride specification is a list of integer values that are either static
/// or dynamic (encoded with ShapedType::kDynamic). Strides encode
/// the distance in the number of elements between successive entries along a
/// particular dimension.
///
/// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
/// non-contiguous memory region of `42` by `16` `f32` elements in which the
/// distance between two consecutive elements along the outer dimension is `64`
/// and the distance between two consecutive elements along the inner dimension
/// is `1`.
///
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order to make indexing intuitive into the result.
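///
/// In the example above, the extracted strides are `[64, 1]` and the
/// extracted offset is `0`.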
static LogicalResult getStridesAndOffset(MemRefType t,
                                         SmallVectorImpl<AffineExpr> &strides,
                                         AffineExpr &offset) {
  AffineMap m = t.getLayout().getAffineMap();

  if (m.getNumResults() != 1 && !m.isIdentity())
    return failure();

  auto zero = getAffineConstantExpr(0, t.getContext());
  auto one = getAffineConstantExpr(1, t.getContext());
  offset = zero;
  strides.assign(t.getRank(), zero);

  // Canonical case for empty map.
  if (m.isIdentity()) {
    // 0-D corner case, offset is already 0.
    if (t.getRank() == 0)
      return success();
    auto stridedExpr =
        makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
      return success();
    assert(false && "unexpected failure: extract strides in canonical layout");
  }

  // Non-canonical case requires more work.
  auto stridedExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  // Simplify results to allow folding to constants and simple checks.
  unsigned numDims = m.getNumDims();
  unsigned numSymbols = m.getNumSymbols();
  offset = simplifyAffineExpr(offset, numDims, numSymbols);
  for (auto &stride : strides)
    stride = simplifyAffineExpr(stride, numDims, numSymbols);

  // In practice, a strided memref must be internally non-aliasing. Test
  // against 0 as a proxy.
  // TODO: static cases can have more advanced checks.
  // TODO: dynamic cases would require a way to compare symbolic
  // expressions and would probably need an affine set context propagated
  // everywhere.
  if (llvm::any_of(strides, [](AffineExpr e) {
        return e == getAffineConstantExpr(0, e.getContext());
      })) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  return success();
}

LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<int64_t> &strides,
                                        int64_t &offset) {
  // Happy path: the type uses the strided layout directly.
  if (auto strided = llvm::dyn_cast<StridedLayoutAttr>(t.getLayout())) {
    llvm::append_range(strides, strided.getStrides());
    offset = strided.getOffset();
    return success();
  }

  // Otherwise, defer to the affine fallback as layouts are supposed to be
  // convertible to affine maps.
  AffineExpr offsetExpr;
  SmallVector<AffineExpr, 4> strideExprs;
  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
    return failure();
  if (auto cst = dyn_cast<AffineConstantExpr>(offsetExpr))
    offset = cst.getValue();
  else
    offset = ShapedType::kDynamic;
  for (auto e : strideExprs) {
    if (auto c = dyn_cast<AffineConstantExpr>(e))
      strides.push_back(c.getValue());
    else
      strides.push_back(ShapedType::kDynamic);
  }
  return success();
}

std::pair<SmallVector<int64_t>, int64_t>
mlir::getStridesAndOffset(MemRefType t) {
  SmallVector<int64_t> strides;
  int64_t offset;
  LogicalResult status = getStridesAndOffset(t, strides, offset);
  (void)status;
  assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset");
  return {strides, offset};
}

//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//

/// Return the element types for this tuple.
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }

/// Accumulate the types contained in this tuple and tuples nested within it.
/// Note that this only flattens nested tuples, not any other container type,
/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
/// (i32, tensor<i32>, f32, i64).
void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
  for (Type type : getTypes()) {
    if (auto nestedTuple = llvm::dyn_cast<TupleType>(type))
      nestedTuple.getFlattenedTypes(types);
    else
      types.push_back(type);
  }
}

/// Return the number of element types.
size_t TupleType::size() const { return getImpl()->size(); }

//===----------------------------------------------------------------------===//
// Type Utilities
//===----------------------------------------------------------------------===//

/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
/// `t` with simplified layout.
/// If `t` has multiple layout maps or a multi-result layout, just return `t`.
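///
/// For example (illustrative):
/// `memref<4x5xf32, affine_map<(d0, d1) -> (d0 * 5 + d1)>>` canonicalizes to
/// `memref<4x5xf32>`, since that layout is the canonical contiguous one.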
MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
  AffineMap m = t.getLayout().getAffineMap();

  // Already in canonical form.
  if (m.isIdentity())
    return t;

  // Can't reduce to canonical identity form, return in canonical form.
  if (m.getNumResults() > 1)
    return t;

  // Corner-case for 0-D affine maps.
  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
    if (auto cst = dyn_cast<AffineConstantExpr>(m.getResult(0)))
      if (cst.getValue() == 0)
        return MemRefType::Builder(t).setLayout({});
    return t;
  }

  // 0-D corner case for an empty shape that still has an affine map. Example:
  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1-element memref
  // whose offset needs to remain, just return t.
  if (t.getShape().empty())
    return t;

  // If the canonical strided layout for the sizes of `t` is equal to the
  // simplified layout of `t` we can just return an empty layout. Otherwise,
  // just simplify the existing layout.
  AffineExpr expr =
      makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
  auto simplifiedLayoutExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (expr != simplifiedLayoutExpr)
    return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get(
        m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
  return MemRefType::Builder(t).setLayout({});
}

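/// Return the canonical contiguous strided layout expression for `sizes`,
/// using `exprs` as the per-dimension expressions. For example
/// (illustrative): sizes [3, 4, 5] produce `d0 * 20 + d1 * 5 + d2`, and a
/// dynamic size makes every stride to its left a fresh symbol.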
AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                ArrayRef<AffineExpr> exprs,
                                                MLIRContext *context) {
  // Size 0 corner case is useful for canonicalizations.
  if (sizes.empty())
    return getAffineConstantExpr(0, context);

  assert(!exprs.empty() && "expected exprs");
  auto maps = AffineMap::inferFromExprList(exprs, context);
  assert(!maps.empty() && "Expected one non-empty map");
  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();

  AffineExpr expr;
  bool dynamicPoisonBit = false;
  int64_t runningSize = 1;
  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
    int64_t size = std::get<1>(en);
    AffineExpr dimExpr = std::get<0>(en);
    AffineExpr stride = dynamicPoisonBit
                            ? getAffineSymbolExpr(nSymbols++, context)
                            : getAffineConstantExpr(runningSize, context);
    expr = expr ? expr + dimExpr * stride : dimExpr * stride;
    if (size > 0) {
      runningSize *= size;
      assert(runningSize > 0 && "integer overflow in size computation");
    } else {
      dynamicPoisonBit = true;
    }
  }
  return simplifyAffineExpr(expr, numDims, nSymbols);
}

AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                MLIRContext *context) {
  SmallVector<AffineExpr, 4> exprs;
  exprs.reserve(sizes.size());
  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
    exprs.push_back(getAffineDimExpr(dim, context));
  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
}

bool mlir::isStrided(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(t, strides, offset);
  return succeeded(res);
}

bool mlir::isLastMemrefDimUnitStride(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t> strides;
  auto successStrides = getStridesAndOffset(type, strides, offset);
  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
}

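/// Return true if the last `n` dimensions of `type` are contiguous, i.e.
/// their strides form the suffix of a canonical row-major layout. For
/// example (illustrative): `memref<5x4x3x2xi8, strided<[24, 6, 2, 1]>>`
/// returns true for any n <= 4.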
bool mlir::trailingNDimsContiguous(MemRefType type, int64_t n) {
  if (!isLastMemrefDimUnitStride(type))
    return false;

  auto memrefShape = type.getShape().take_back(n);
  if (ShapedType::isDynamicShape(memrefShape))
    return false;

  if (type.getLayout().isIdentity())
    return true;

  int64_t offset;
  SmallVector<int64_t> stridesFull;
  if (!succeeded(getStridesAndOffset(type, stridesFull, offset)))
    return false;
  auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);

  if (strides.empty())
    return true;

  // Check whether strides match "flattened" dims.
  SmallVector<int64_t> flattenedDims;
  // Use int64_t to match the stride element type and avoid narrowing.
  int64_t dimProduct = 1;
  for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
    dimProduct *= dim;
    flattenedDims.push_back(dimProduct);
  }

  strides = strides.drop_back(1);
  return llvm::equal(strides, llvm::reverse(flattenedDims));
}