1 | //===- Dialect.cpp - Implementation of the linalg dialect and types -------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements the Linalg dialect types and dialect. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
15 | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
16 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
17 | #include "mlir/Dialect/Math/IR/Math.h" |
18 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
19 | #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" |
20 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
21 | #include "mlir/IR/BuiltinTypes.h" |
22 | #include "mlir/IR/Dialect.h" |
23 | #include "mlir/IR/DialectImplementation.h" |
24 | #include "mlir/Interfaces/DestinationStyleOpInterface.h" |
25 | #include "mlir/Interfaces/FunctionInterfaces.h" |
26 | #include "mlir/Interfaces/SubsetOpInterface.h" |
27 | #include "mlir/Interfaces/ValueBoundsOpInterface.h" |
28 | #include "mlir/Parser/Parser.h" |
29 | #include "mlir/Support/LLVM.h" |
30 | #include "mlir/Transforms/InliningUtils.h" |
31 | |
32 | #include "llvm/ADT/StringExtras.h" |
33 | #include "llvm/ADT/TypeSwitch.h" |
34 | #include "llvm/Support/raw_ostream.h" |
35 | |
36 | using namespace mlir; |
37 | using namespace mlir::linalg; |
38 | |
39 | //===----------------------------------------------------------------------===// |
40 | // LinalgDialect Dialect Interfaces |
41 | //===----------------------------------------------------------------------===// |
42 | |
43 | namespace { |
44 | |
45 | struct LinalgInlinerInterface : public DialectInlinerInterface { |
46 | using DialectInlinerInterface::DialectInlinerInterface; |
47 | |
48 | // We don't have any special restrictions on what can be inlined into |
49 | // destination regions (e.g. while/conditional bodies). Always allow it. |
50 | bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, |
51 | IRMapping &valueMapping) const final { |
52 | return true; |
53 | } |
54 | // Operations in Linalg dialect are always legal to inline. |
55 | bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { |
56 | return true; |
57 | } |
58 | // Handle the given inlined terminator by replacing it with a new operation |
59 | // as necessary. Required when the region has only one block. |
60 | void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {} |
61 | }; |
62 | |
63 | } // namespace |
64 | |
65 | //===----------------------------------------------------------------------===// |
66 | // LinalgDialect |
67 | //===----------------------------------------------------------------------===// |
68 | |
/// Attribute name used to memoize indexing maps for named ops.
///
/// This is an out-of-line definition (no initializer) of a `constexpr` static
/// data member, which is presumably declared with its value inside the
/// generated `LinalgDialect` class. Since C++17, `constexpr` static data
/// members are implicitly `inline`, so this redeclaration is redundant (and
/// deprecated). It is kept only for compatibility with older-standard builds.
constexpr const ::llvm::StringLiteral
    LinalgDialect::kMemoizedIndexingMapsAttrName;
72 | |
73 | /// Trait to check if T provides a `regionBuilder` method. |
74 | template <typename T, typename... Args> |
75 | using has_region_builder = decltype(T::regionBuilder); |
76 | template <typename T> |
77 | using detect_has_region_builder = llvm::is_detected<has_region_builder, T>; |
78 | |
79 | /// SFINAE helper for single C++ class without a `regionBuilder` method (e.g. |
80 | /// an OpInterface). |
81 | template <typename OpType, typename = std::enable_if_t< |
82 | !detect_has_region_builder<OpType>::value>> |
83 | void addNamedOpBuilderImpl( |
84 | llvm::StringMap<LinalgDialect::RegionBuilderFunType> &map) { |
85 | // Do nothing. |
86 | } |
87 | |
88 | template <typename OpType, |
89 | typename = std::enable_if_t<detect_has_region_builder<OpType>::value>, |
90 | typename = void> |
91 | void addNamedOpBuilderImpl( |
92 | llvm::StringMap<LinalgDialect::RegionBuilderFunType> &map) { |
93 | map.insert(std::make_pair( |
94 | OpType::getOperationName(), |
95 | static_cast<LinalgDialect::RegionBuilderFunType>(OpType::regionBuilder))); |
96 | } |
97 | |
98 | template <typename... OpTypes> |
99 | void addNamedOpBuilders( |
100 | llvm::StringMap<LinalgDialect::RegionBuilderFunType> &map) { |
101 | (addNamedOpBuilderImpl<OpTypes>(map), ...); |
102 | } |
103 | |
/// Dialect initialization. This function:
///  - registers the dialect's attributes and operations (both the
///    LinalgOps and LinalgStructuredOps lists),
///  - fills `namedStructuredOpRegionBuilders` from the structured-op list,
///  - installs the inliner interface, and
///  - declares promised interfaces, i.e. interfaces whose implementations
///    are expected to be attached to these ops from elsewhere.
///
/// Each `#define GET_OP_LIST` / `#include "...cpp.inc"` pair expands, inside
/// the template argument list, to the comma-separated list of generated op
/// classes for that .td file (`GET_ATTRDEF_LIST` does the same for
/// attributes).
void mlir::linalg::LinalgDialect::initialize() {
  // Generated attribute definitions.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.cpp.inc"
      >();
  // Generated non-structured ops.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
      >();
  // Generated structured ops.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
      >();

  // Fill the Linalg-specific OpName to RegionBuilder map. The whole
  // structured-op list is passed; classes without a `regionBuilder` member
  // are skipped by `addNamedOpBuilderImpl`.
  addNamedOpBuilders<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
      >(namedStructuredOpRegionBuilders);

  // Inlining hooks (all Linalg ops/regions are legal to inline).
  addInterfaces<LinalgInlinerInterface>();

  // Promised interfaces. The implementations are not registered in this
  // function; they are declared here as promised so they can be attached
  // later as external models.
  // NOTE(review): GenericOp is presumably part of the LinalgStructuredOps op
  // list, in which case the single-op GenericOp declarations below
  // (ShardingInterface, TilingInterface, PartialReductionOpInterface) are
  // redundant with the list-based ones. Confirm before removing.
  declarePromisedInterface<mesh::ShardingInterface, GenericOp>();
  declarePromisedInterfaces<mesh::ShardingInterface,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
                            >();
  // Subset interfaces for linalg.copy.
  declarePromisedInterface<SubsetOpInterface, CopyOp>();
  declarePromisedInterface<SubsetInsertionOpInterface, CopyOp>();
  // Value-bounds reasoning for linalg.index.
  declarePromisedInterface<ValueBoundsOpInterface, IndexOp>();
  declarePromisedInterface<TilingInterface, linalg::GenericOp>();
  declarePromisedInterface<PartialReductionOpInterface, linalg::GenericOp>();
  // Tiling / partial-reduction / bufferization for every structured op.
  declarePromisedInterfaces<TilingInterface,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
                            >();
  declarePromisedInterfaces<PartialReductionOpInterface,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
                            >();
  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
                            >();
}
149 | |
150 | LogicalResult LinalgDialect::verifyOperationAttribute(Operation *op, |
151 | NamedAttribute attr) { |
152 | if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName) |
153 | return success(); |
154 | return op->emitError() << "attribute '" << attr.getName() |
155 | << "' not supported by the linalg dialect" ; |
156 | } |
157 | |
158 | #include "mlir/Dialect/Linalg/IR/LinalgOpsEnums.cpp.inc" |
159 | |
160 | #define GET_ATTRDEF_CLASSES |
161 | #include "mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.cpp.inc" |
162 | |
163 | #include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.cpp.inc" |
164 | |