1 | //===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" |
10 | |
11 | #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" |
12 | #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" |
13 | #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" |
14 | #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" |
15 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
16 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
17 | #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" |
18 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
19 | #include "mlir/Dialect/OpenMP/OpenMPDialect.h" |
20 | #include "mlir/Pass/Pass.h" |
21 | |
22 | namespace mlir { |
23 | #define GEN_PASS_DEF_CONVERTOPENMPTOLLVMPASS |
24 | #include "mlir/Conversion/Passes.h.inc" |
25 | } // namespace mlir |
26 | |
27 | using namespace mlir; |
28 | |
29 | namespace { |
30 | /// A pattern that converts the region arguments in a single-region OpenMP |
31 | /// operation to the LLVM dialect. The body of the region is not modified and is |
32 | /// expected to either be processed by the conversion infrastructure or already |
33 | /// contain ops compatible with LLVM dialect types. |
34 | template <typename OpType> |
35 | struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> { |
36 | using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern; |
37 | |
38 | LogicalResult |
39 | matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor, |
40 | ConversionPatternRewriter &rewriter) const override { |
41 | auto newOp = rewriter.create<OpType>( |
42 | curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs()); |
43 | rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(), |
44 | newOp.getRegion().end()); |
45 | if (failed(rewriter.convertRegionTypes(region: &newOp.getRegion(), |
46 | converter: *this->getTypeConverter()))) |
47 | return failure(); |
48 | |
49 | rewriter.eraseOp(op: curOp); |
50 | return success(); |
51 | } |
52 | }; |
53 | |
54 | template <typename T> |
55 | struct RegionLessOpWithVarOperandsConversion |
56 | : public ConvertOpToLLVMPattern<T> { |
57 | using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern; |
58 | LogicalResult |
59 | matchAndRewrite(T curOp, typename T::Adaptor adaptor, |
60 | ConversionPatternRewriter &rewriter) const override { |
61 | const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); |
62 | SmallVector<Type> resTypes; |
63 | if (failed(converter->convertTypes(types: curOp->getResultTypes(), results&: resTypes))) |
64 | return failure(); |
65 | SmallVector<Value> convertedOperands; |
66 | assert(curOp.getNumVariableOperands() == |
67 | curOp.getOperation()->getNumOperands() && |
68 | "unexpected non-variable operands" ); |
69 | for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) { |
70 | Value originalVariableOperand = curOp.getVariableOperand(idx); |
71 | if (!originalVariableOperand) |
72 | return failure(); |
73 | if (isa<MemRefType>(originalVariableOperand.getType())) { |
74 | // TODO: Support memref type in variable operands |
75 | return rewriter.notifyMatchFailure(curOp, |
76 | "memref is not supported yet" ); |
77 | } |
78 | convertedOperands.emplace_back(adaptor.getOperands()[idx]); |
79 | } |
80 | |
81 | rewriter.replaceOpWithNewOp<T>(curOp, resTypes, convertedOperands, |
82 | curOp->getAttrs()); |
83 | return success(); |
84 | } |
85 | }; |
86 | |
87 | template <typename T> |
88 | struct RegionOpWithVarOperandsConversion : public ConvertOpToLLVMPattern<T> { |
89 | using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern; |
90 | LogicalResult |
91 | matchAndRewrite(T curOp, typename T::Adaptor adaptor, |
92 | ConversionPatternRewriter &rewriter) const override { |
93 | const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); |
94 | SmallVector<Type> resTypes; |
95 | if (failed(converter->convertTypes(types: curOp->getResultTypes(), results&: resTypes))) |
96 | return failure(); |
97 | SmallVector<Value> convertedOperands; |
98 | assert(curOp.getNumVariableOperands() == |
99 | curOp.getOperation()->getNumOperands() && |
100 | "unexpected non-variable operands" ); |
101 | for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) { |
102 | Value originalVariableOperand = curOp.getVariableOperand(idx); |
103 | if (!originalVariableOperand) |
104 | return failure(); |
105 | if (isa<MemRefType>(originalVariableOperand.getType())) { |
106 | // TODO: Support memref type in variable operands |
107 | return rewriter.notifyMatchFailure(curOp, |
108 | "memref is not supported yet" ); |
109 | } |
110 | convertedOperands.emplace_back(adaptor.getOperands()[idx]); |
111 | } |
112 | auto newOp = rewriter.create<T>(curOp.getLoc(), resTypes, convertedOperands, |
113 | curOp->getAttrs()); |
114 | rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(), |
115 | newOp.getRegion().end()); |
116 | if (failed(rewriter.convertRegionTypes(region: &newOp.getRegion(), |
117 | converter: *this->getTypeConverter()))) |
118 | return failure(); |
119 | |
120 | rewriter.eraseOp(op: curOp); |
121 | return success(); |
122 | } |
123 | }; |
124 | |
125 | template <typename T> |
126 | struct RegionLessOpConversion : public ConvertOpToLLVMPattern<T> { |
127 | using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern; |
128 | LogicalResult |
129 | matchAndRewrite(T curOp, typename T::Adaptor adaptor, |
130 | ConversionPatternRewriter &rewriter) const override { |
131 | const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); |
132 | SmallVector<Type> resTypes; |
133 | if (failed(converter->convertTypes(types: curOp->getResultTypes(), results&: resTypes))) |
134 | return failure(); |
135 | |
136 | rewriter.replaceOpWithNewOp<T>(curOp, resTypes, adaptor.getOperands(), |
137 | curOp->getAttrs()); |
138 | return success(); |
139 | } |
140 | }; |
141 | |
142 | struct AtomicReadOpConversion |
143 | : public ConvertOpToLLVMPattern<omp::AtomicReadOp> { |
144 | using ConvertOpToLLVMPattern<omp::AtomicReadOp>::ConvertOpToLLVMPattern; |
145 | LogicalResult |
146 | matchAndRewrite(omp::AtomicReadOp curOp, OpAdaptor adaptor, |
147 | ConversionPatternRewriter &rewriter) const override { |
148 | const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); |
149 | Type curElementType = curOp.getElementType(); |
150 | auto newOp = rewriter.create<omp::AtomicReadOp>( |
151 | curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs()); |
152 | TypeAttr typeAttr = TypeAttr::get(converter->convertType(curElementType)); |
153 | newOp.setElementTypeAttr(typeAttr); |
154 | rewriter.eraseOp(op: curOp); |
155 | return success(); |
156 | } |
157 | }; |
158 | |
159 | struct MapInfoOpConversion : public ConvertOpToLLVMPattern<omp::MapInfoOp> { |
160 | using ConvertOpToLLVMPattern<omp::MapInfoOp>::ConvertOpToLLVMPattern; |
161 | LogicalResult |
162 | matchAndRewrite(omp::MapInfoOp curOp, OpAdaptor adaptor, |
163 | ConversionPatternRewriter &rewriter) const override { |
164 | const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); |
165 | |
166 | SmallVector<Type> resTypes; |
167 | if (failed(converter->convertTypes(types: curOp->getResultTypes(), results&: resTypes))) |
168 | return failure(); |
169 | |
170 | // Copy attributes of the curOp except for the typeAttr which should |
171 | // be converted |
172 | SmallVector<NamedAttribute> newAttrs; |
173 | for (NamedAttribute attr : curOp->getAttrs()) { |
174 | if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) { |
175 | Type newAttr = converter->convertType(typeAttr.getValue()); |
176 | newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr)); |
177 | } else { |
178 | newAttrs.push_back(attr); |
179 | } |
180 | } |
181 | |
182 | rewriter.replaceOpWithNewOp<omp::MapInfoOp>( |
183 | curOp, resTypes, adaptor.getOperands(), newAttrs); |
184 | return success(); |
185 | } |
186 | }; |
187 | |
188 | struct ReductionOpConversion : public ConvertOpToLLVMPattern<omp::ReductionOp> { |
189 | using ConvertOpToLLVMPattern<omp::ReductionOp>::ConvertOpToLLVMPattern; |
190 | LogicalResult |
191 | matchAndRewrite(omp::ReductionOp curOp, OpAdaptor adaptor, |
192 | ConversionPatternRewriter &rewriter) const override { |
193 | if (isa<MemRefType>(curOp.getAccumulator().getType())) { |
194 | // TODO: Support memref type in variable operands |
195 | return rewriter.notifyMatchFailure(curOp, "memref is not supported yet" ); |
196 | } |
197 | rewriter.replaceOpWithNewOp<omp::ReductionOp>( |
198 | curOp, TypeRange(), adaptor.getOperands(), curOp->getAttrs()); |
199 | return success(); |
200 | } |
201 | }; |
202 | |
203 | template <typename OpType> |
204 | struct MultiRegionOpConversion : public ConvertOpToLLVMPattern<OpType> { |
205 | using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern; |
206 | |
207 | void forwardOpAttrs(OpType curOp, OpType newOp) const {} |
208 | |
209 | LogicalResult |
210 | matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor, |
211 | ConversionPatternRewriter &rewriter) const override { |
212 | auto newOp = rewriter.create<OpType>( |
213 | curOp.getLoc(), TypeRange(), curOp.getSymNameAttr(), |
214 | TypeAttr::get(this->getTypeConverter()->convertType( |
215 | curOp.getTypeAttr().getValue()))); |
216 | forwardOpAttrs(curOp, newOp: newOp); |
217 | |
218 | for (unsigned idx = 0; idx < curOp.getNumRegions(); idx++) { |
219 | rewriter.inlineRegionBefore(curOp.getRegion(idx), newOp.getRegion(idx), |
220 | newOp.getRegion(idx).end()); |
221 | if (failed(rewriter.convertRegionTypes(region: &newOp.getRegion(idx), |
222 | converter: *this->getTypeConverter()))) |
223 | return failure(); |
224 | } |
225 | |
226 | rewriter.eraseOp(op: curOp); |
227 | return success(); |
228 | } |
229 | }; |
230 | |
231 | template <> |
232 | void MultiRegionOpConversion<omp::PrivateClauseOp>::forwardOpAttrs( |
233 | omp::PrivateClauseOp curOp, omp::PrivateClauseOp newOp) const { |
234 | newOp.setDataSharingType(curOp.getDataSharingType()); |
235 | } |
236 | } // namespace |
237 | |
238 | void mlir::configureOpenMPToLLVMConversionLegality( |
239 | ConversionTarget &target, LLVMTypeConverter &typeConverter) { |
240 | target.addDynamicallyLegalOp< |
241 | mlir::omp::AtomicReadOp, mlir::omp::AtomicWriteOp, mlir::omp::FlushOp, |
242 | mlir::omp::ThreadprivateOp, mlir::omp::YieldOp, |
243 | mlir::omp::TargetEnterDataOp, mlir::omp::TargetExitDataOp, |
244 | mlir::omp::TargetUpdateOp, mlir::omp::MapBoundsOp, mlir::omp::MapInfoOp>( |
245 | [&](Operation *op) { |
246 | return typeConverter.isLegal(op->getOperandTypes()) && |
247 | typeConverter.isLegal(op->getResultTypes()); |
248 | }); |
249 | target.addDynamicallyLegalOp<mlir::omp::ReductionOp>([&](Operation *op) { |
250 | return typeConverter.isLegal(op->getOperandTypes()); |
251 | }); |
252 | target.addDynamicallyLegalOp< |
253 | mlir::omp::AtomicUpdateOp, mlir::omp::CriticalOp, mlir::omp::TargetOp, |
254 | mlir::omp::TargetDataOp, mlir::omp::LoopNestOp, |
255 | mlir::omp::OrderedRegionOp, mlir::omp::ParallelOp, mlir::omp::WsloopOp, |
256 | mlir::omp::SimdOp, mlir::omp::MasterOp, mlir::omp::SectionOp, |
257 | mlir::omp::SectionsOp, mlir::omp::SingleOp, mlir::omp::TaskgroupOp, |
258 | mlir::omp::TaskOp, mlir::omp::DeclareReductionOp, |
259 | mlir::omp::PrivateClauseOp>([&](Operation *op) { |
260 | return std::all_of(op->getRegions().begin(), op->getRegions().end(), |
261 | [&](Region ®ion) { |
262 | return typeConverter.isLegal(®ion); |
263 | }) && |
264 | typeConverter.isLegal(op->getOperandTypes()) && |
265 | typeConverter.isLegal(op->getResultTypes()); |
266 | }); |
267 | } |
268 | |
269 | void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, |
270 | RewritePatternSet &patterns) { |
271 | // This type is allowed when converting OpenMP to LLVM Dialect, it carries |
272 | // bounds information for map clauses and the operation and type are |
273 | // discarded on lowering to LLVM-IR from the OpenMP dialect. |
274 | converter.addConversion( |
275 | callback: [&](omp::MapBoundsType type) -> Type { return type; }); |
276 | |
277 | patterns.add< |
278 | AtomicReadOpConversion, MapInfoOpConversion, ReductionOpConversion, |
279 | MultiRegionOpConversion<omp::DeclareReductionOp>, |
280 | MultiRegionOpConversion<omp::PrivateClauseOp>, |
281 | RegionOpConversion<omp::CriticalOp>, RegionOpConversion<omp::LoopNestOp>, |
282 | RegionOpConversion<omp::MasterOp>, ReductionOpConversion, |
283 | RegionOpConversion<omp::OrderedRegionOp>, |
284 | RegionOpConversion<omp::ParallelOp>, RegionOpConversion<omp::WsloopOp>, |
285 | RegionOpConversion<omp::SectionsOp>, RegionOpConversion<omp::SectionOp>, |
286 | RegionOpConversion<omp::SimdOp>, RegionOpConversion<omp::SingleOp>, |
287 | RegionOpConversion<omp::TaskgroupOp>, RegionOpConversion<omp::TaskOp>, |
288 | RegionOpConversion<omp::TargetDataOp>, RegionOpConversion<omp::TargetOp>, |
289 | RegionLessOpWithVarOperandsConversion<omp::AtomicWriteOp>, |
290 | RegionOpWithVarOperandsConversion<omp::AtomicUpdateOp>, |
291 | RegionLessOpWithVarOperandsConversion<omp::FlushOp>, |
292 | RegionLessOpWithVarOperandsConversion<omp::ThreadprivateOp>, |
293 | RegionLessOpConversion<omp::YieldOp>, |
294 | RegionLessOpConversion<omp::TargetEnterDataOp>, |
295 | RegionLessOpConversion<omp::TargetExitDataOp>, |
296 | RegionLessOpConversion<omp::TargetUpdateOp>, |
297 | RegionLessOpWithVarOperandsConversion<omp::MapBoundsOp>>(converter); |
298 | } |
299 | |
300 | namespace { |
301 | struct ConvertOpenMPToLLVMPass |
302 | : public impl::ConvertOpenMPToLLVMPassBase<ConvertOpenMPToLLVMPass> { |
303 | using Base::Base; |
304 | |
305 | void runOnOperation() override; |
306 | }; |
307 | } // namespace |
308 | |
309 | void ConvertOpenMPToLLVMPass::runOnOperation() { |
310 | auto module = getOperation(); |
311 | |
312 | // Convert to OpenMP operations with LLVM IR dialect |
313 | RewritePatternSet patterns(&getContext()); |
314 | LLVMTypeConverter converter(&getContext()); |
315 | arith::populateArithToLLVMConversionPatterns(converter, patterns); |
316 | cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); |
317 | populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns); |
318 | populateFuncToLLVMConversionPatterns(converter, patterns); |
319 | populateOpenMPToLLVMConversionPatterns(converter, patterns); |
320 | |
321 | LLVMConversionTarget target(getContext()); |
322 | target.addLegalOp<omp::TerminatorOp, omp::TaskyieldOp, omp::FlushOp, |
323 | omp::BarrierOp, omp::TaskwaitOp>(); |
324 | configureOpenMPToLLVMConversionLegality(target, typeConverter&: converter); |
325 | if (failed(applyPartialConversion(module, target, std::move(patterns)))) |
326 | signalPassFailure(); |
327 | } |
328 | |