1 | //===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" |
10 | |
11 | #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" |
12 | #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" |
13 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
14 | #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" |
15 | #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" |
16 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
17 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
18 | #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" |
19 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
20 | #include "mlir/Dialect/OpenMP/OpenMPDialect.h" |
21 | #include "mlir/Pass/Pass.h" |
22 | |
23 | namespace mlir { |
24 | #define GEN_PASS_DEF_CONVERTOPENMPTOLLVMPASS |
25 | #include "mlir/Conversion/Passes.h.inc" |
26 | } // namespace mlir |
27 | |
28 | using namespace mlir; |
29 | |
30 | namespace { |
31 | |
32 | /// A pattern that converts the result and operand types, attributes, and region |
33 | /// arguments of an OpenMP operation to the LLVM dialect. |
34 | /// |
35 | /// Attributes are copied verbatim by default, and only translated if they are |
36 | /// type attributes. |
37 | /// |
38 | /// Region bodies, if any, are not modified and expected to either be processed |
39 | /// by the conversion infrastructure or already contain ops compatible with LLVM |
40 | /// dialect types. |
41 | template <typename T> |
42 | struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> { |
43 | using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern; |
44 | |
45 | LogicalResult |
46 | matchAndRewrite(T op, typename T::Adaptor adaptor, |
47 | ConversionPatternRewriter &rewriter) const override { |
48 | // Translate result types. |
49 | const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); |
50 | SmallVector<Type> resTypes; |
51 | if (failed(converter->convertTypes(types: op->getResultTypes(), results&: resTypes))) |
52 | return failure(); |
53 | |
54 | // Translate type attributes. |
55 | // They are kept unmodified except if they are type attributes. |
56 | SmallVector<NamedAttribute> convertedAttrs; |
57 | for (NamedAttribute attr : op->getAttrs()) { |
58 | if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) { |
59 | Type convertedType = converter->convertType(typeAttr.getValue()); |
60 | convertedAttrs.emplace_back(attr.getName(), |
61 | TypeAttr::get(convertedType)); |
62 | } else { |
63 | convertedAttrs.push_back(attr); |
64 | } |
65 | } |
66 | |
67 | // Translate operands. |
68 | SmallVector<Value> convertedOperands; |
69 | convertedOperands.reserve(op->getNumOperands()); |
70 | for (auto [originalOperand, convertedOperand] : |
71 | llvm::zip_equal(op->getOperands(), adaptor.getOperands())) { |
72 | if (!originalOperand) |
73 | return failure(); |
74 | |
75 | // TODO: Revisit whether we need to trigger an error specifically for this |
76 | // set of operations. Consider removing this check or updating the list. |
77 | if constexpr (llvm::is_one_of<T, omp::AtomicUpdateOp, omp::AtomicWriteOp, |
78 | omp::FlushOp, omp::MapBoundsOp, |
79 | omp::ThreadprivateOp>::value) { |
80 | if (isa<MemRefType>(originalOperand.getType())) { |
81 | // TODO: Support memref type in variable operands |
82 | return rewriter.notifyMatchFailure(op, "memref is not supported yet" ); |
83 | } |
84 | } |
85 | convertedOperands.push_back(convertedOperand); |
86 | } |
87 | |
88 | // Create new operation. |
89 | auto newOp = rewriter.create<T>(op.getLoc(), resTypes, convertedOperands, |
90 | convertedAttrs); |
91 | |
92 | // Translate regions. |
93 | for (auto [originalRegion, convertedRegion] : |
94 | llvm::zip_equal(op->getRegions(), newOp->getRegions())) { |
95 | rewriter.inlineRegionBefore(originalRegion, convertedRegion, |
96 | convertedRegion.end()); |
97 | if (failed(rewriter.convertRegionTypes(&convertedRegion, |
98 | *this->getTypeConverter()))) |
99 | return failure(); |
100 | } |
101 | |
102 | // Delete old operation and replace result uses with those of the new one. |
103 | rewriter.replaceOp(op, newOp->getResults()); |
104 | return success(); |
105 | } |
106 | }; |
107 | |
108 | } // namespace |
109 | |
110 | void mlir::configureOpenMPToLLVMConversionLegality( |
111 | ConversionTarget &target, const LLVMTypeConverter &typeConverter) { |
112 | target.addDynamicallyLegalOp< |
113 | #define GET_OP_LIST |
114 | #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" |
115 | >([&](Operation *op) { |
116 | return typeConverter.isLegal(op->getOperandTypes()) && |
117 | typeConverter.isLegal(op->getResultTypes()) && |
118 | llvm::all_of(op->getRegions(), |
119 | [&](Region ®ion) { |
120 | return typeConverter.isLegal(®ion); |
121 | }) && |
122 | llvm::all_of(op->getAttrs(), [&](NamedAttribute attr) { |
123 | auto typeAttr = dyn_cast<TypeAttr>(attr.getValue()); |
124 | return !typeAttr || typeConverter.isLegal(typeAttr.getValue()); |
125 | }); |
126 | }); |
127 | } |
128 | |
129 | /// Add an `OpenMPOpConversion<T>` conversion pattern for each operation type |
130 | /// passed as template argument. |
131 | template <typename... Ts> |
132 | static inline RewritePatternSet & |
133 | addOpenMPOpConversions(LLVMTypeConverter &converter, |
134 | RewritePatternSet &patterns) { |
135 | return patterns.add<OpenMPOpConversion<Ts>...>(converter); |
136 | } |
137 | |
138 | void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, |
139 | RewritePatternSet &patterns) { |
140 | // This type is allowed when converting OpenMP to LLVM Dialect, it carries |
141 | // bounds information for map clauses and the operation and type are |
142 | // discarded on lowering to LLVM-IR from the OpenMP dialect. |
143 | converter.addConversion( |
144 | callback: [&](omp::MapBoundsType type) -> Type { return type; }); |
145 | |
146 | // Add conversions for all OpenMP operations. |
147 | addOpenMPOpConversions< |
148 | #define GET_OP_LIST |
149 | #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" |
150 | >(converter, patterns); |
151 | } |
152 | |
153 | namespace { |
154 | struct ConvertOpenMPToLLVMPass |
155 | : public impl::ConvertOpenMPToLLVMPassBase<ConvertOpenMPToLLVMPass> { |
156 | using Base::Base; |
157 | |
158 | void runOnOperation() override; |
159 | }; |
160 | } // namespace |
161 | |
162 | void ConvertOpenMPToLLVMPass::runOnOperation() { |
163 | auto module = getOperation(); |
164 | |
165 | // Convert to OpenMP operations with LLVM IR dialect |
166 | RewritePatternSet patterns(&getContext()); |
167 | LLVMTypeConverter converter(&getContext()); |
168 | arith::populateArithToLLVMConversionPatterns(converter, patterns); |
169 | cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); |
170 | cf::populateAssertToLLVMConversionPattern(converter, patterns); |
171 | populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns); |
172 | populateFuncToLLVMConversionPatterns(converter, patterns); |
173 | populateOpenMPToLLVMConversionPatterns(converter, patterns); |
174 | |
175 | LLVMConversionTarget target(getContext()); |
176 | target.addLegalOp<omp::BarrierOp, omp::FlushOp, omp::TaskwaitOp, |
177 | omp::TaskyieldOp, omp::TerminatorOp>(); |
178 | configureOpenMPToLLVMConversionLegality(target, typeConverter: converter); |
179 | if (failed(applyPartialConversion(module, target, std::move(patterns)))) |
180 | signalPassFailure(); |
181 | } |
182 | |
183 | //===----------------------------------------------------------------------===// |
184 | // ConvertToLLVMPatternInterface implementation |
185 | //===----------------------------------------------------------------------===// |
186 | namespace { |
187 | /// Implement the interface to convert OpenMP to LLVM. |
188 | struct OpenMPToLLVMDialectInterface : public ConvertToLLVMPatternInterface { |
189 | using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; |
190 | void loadDependentDialects(MLIRContext *context) const final { |
191 | context->loadDialect<LLVM::LLVMDialect>(); |
192 | } |
193 | |
194 | /// Hook for derived dialect interface to provide conversion patterns |
195 | /// and mark dialect legal for the conversion target. |
196 | void populateConvertToLLVMConversionPatterns( |
197 | ConversionTarget &target, LLVMTypeConverter &typeConverter, |
198 | RewritePatternSet &patterns) const final { |
199 | configureOpenMPToLLVMConversionLegality(target, typeConverter); |
200 | populateOpenMPToLLVMConversionPatterns(converter&: typeConverter, patterns); |
201 | } |
202 | }; |
203 | } // namespace |
204 | |
205 | void mlir::registerConvertOpenMPToLLVMInterface(DialectRegistry ®istry) { |
206 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, omp::OpenMPDialect *dialect) { |
207 | dialect->addInterfaces<OpenMPToLLVMDialectInterface>(); |
208 | }); |
209 | } |
210 | |