1//===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
10
11#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
12#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
13#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
14#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
15#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
16#include "mlir/Conversion/LLVMCommon/Pattern.h"
17#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
18#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
20#include "mlir/Pass/Pass.h"
21
22namespace mlir {
23#define GEN_PASS_DEF_CONVERTOPENMPTOLLVMPASS
24#include "mlir/Conversion/Passes.h.inc"
25} // namespace mlir
26
27using namespace mlir;
28
29namespace {
30/// A pattern that converts the region arguments in a single-region OpenMP
31/// operation to the LLVM dialect. The body of the region is not modified and is
32/// expected to either be processed by the conversion infrastructure or already
33/// contain ops compatible with LLVM dialect types.
34template <typename OpType>
35struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
36 using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern;
37
38 LogicalResult
39 matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor,
40 ConversionPatternRewriter &rewriter) const override {
41 auto newOp = rewriter.create<OpType>(
42 curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs());
43 rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(),
44 newOp.getRegion().end());
45 if (failed(rewriter.convertRegionTypes(region: &newOp.getRegion(),
46 converter: *this->getTypeConverter())))
47 return failure();
48
49 rewriter.eraseOp(op: curOp);
50 return success();
51 }
52};
53
54template <typename T>
55struct RegionLessOpWithVarOperandsConversion
56 : public ConvertOpToLLVMPattern<T> {
57 using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
58 LogicalResult
59 matchAndRewrite(T curOp, typename T::Adaptor adaptor,
60 ConversionPatternRewriter &rewriter) const override {
61 const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
62 SmallVector<Type> resTypes;
63 if (failed(converter->convertTypes(types: curOp->getResultTypes(), results&: resTypes)))
64 return failure();
65 SmallVector<Value> convertedOperands;
66 assert(curOp.getNumVariableOperands() ==
67 curOp.getOperation()->getNumOperands() &&
68 "unexpected non-variable operands");
69 for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) {
70 Value originalVariableOperand = curOp.getVariableOperand(idx);
71 if (!originalVariableOperand)
72 return failure();
73 if (isa<MemRefType>(originalVariableOperand.getType())) {
74 // TODO: Support memref type in variable operands
75 return rewriter.notifyMatchFailure(curOp,
76 "memref is not supported yet");
77 }
78 convertedOperands.emplace_back(adaptor.getOperands()[idx]);
79 }
80
81 rewriter.replaceOpWithNewOp<T>(curOp, resTypes, convertedOperands,
82 curOp->getAttrs());
83 return success();
84 }
85};
86
87template <typename T>
88struct RegionOpWithVarOperandsConversion : public ConvertOpToLLVMPattern<T> {
89 using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
90 LogicalResult
91 matchAndRewrite(T curOp, typename T::Adaptor adaptor,
92 ConversionPatternRewriter &rewriter) const override {
93 const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
94 SmallVector<Type> resTypes;
95 if (failed(converter->convertTypes(types: curOp->getResultTypes(), results&: resTypes)))
96 return failure();
97 SmallVector<Value> convertedOperands;
98 assert(curOp.getNumVariableOperands() ==
99 curOp.getOperation()->getNumOperands() &&
100 "unexpected non-variable operands");
101 for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) {
102 Value originalVariableOperand = curOp.getVariableOperand(idx);
103 if (!originalVariableOperand)
104 return failure();
105 if (isa<MemRefType>(originalVariableOperand.getType())) {
106 // TODO: Support memref type in variable operands
107 return rewriter.notifyMatchFailure(curOp,
108 "memref is not supported yet");
109 }
110 convertedOperands.emplace_back(adaptor.getOperands()[idx]);
111 }
112 auto newOp = rewriter.create<T>(curOp.getLoc(), resTypes, convertedOperands,
113 curOp->getAttrs());
114 rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(),
115 newOp.getRegion().end());
116 if (failed(rewriter.convertRegionTypes(region: &newOp.getRegion(),
117 converter: *this->getTypeConverter())))
118 return failure();
119
120 rewriter.eraseOp(op: curOp);
121 return success();
122 }
123};
124
125template <typename T>
126struct RegionLessOpConversion : public ConvertOpToLLVMPattern<T> {
127 using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
128 LogicalResult
129 matchAndRewrite(T curOp, typename T::Adaptor adaptor,
130 ConversionPatternRewriter &rewriter) const override {
131 const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
132 SmallVector<Type> resTypes;
133 if (failed(converter->convertTypes(types: curOp->getResultTypes(), results&: resTypes)))
134 return failure();
135
136 rewriter.replaceOpWithNewOp<T>(curOp, resTypes, adaptor.getOperands(),
137 curOp->getAttrs());
138 return success();
139 }
140};
141
142struct AtomicReadOpConversion
143 : public ConvertOpToLLVMPattern<omp::AtomicReadOp> {
144 using ConvertOpToLLVMPattern<omp::AtomicReadOp>::ConvertOpToLLVMPattern;
145 LogicalResult
146 matchAndRewrite(omp::AtomicReadOp curOp, OpAdaptor adaptor,
147 ConversionPatternRewriter &rewriter) const override {
148 const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
149 Type curElementType = curOp.getElementType();
150 auto newOp = rewriter.create<omp::AtomicReadOp>(
151 curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs());
152 TypeAttr typeAttr = TypeAttr::get(converter->convertType(curElementType));
153 newOp.setElementTypeAttr(typeAttr);
154 rewriter.eraseOp(op: curOp);
155 return success();
156 }
157};
158
159struct MapInfoOpConversion : public ConvertOpToLLVMPattern<omp::MapInfoOp> {
160 using ConvertOpToLLVMPattern<omp::MapInfoOp>::ConvertOpToLLVMPattern;
161 LogicalResult
162 matchAndRewrite(omp::MapInfoOp curOp, OpAdaptor adaptor,
163 ConversionPatternRewriter &rewriter) const override {
164 const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
165
166 SmallVector<Type> resTypes;
167 if (failed(converter->convertTypes(types: curOp->getResultTypes(), results&: resTypes)))
168 return failure();
169
170 // Copy attributes of the curOp except for the typeAttr which should
171 // be converted
172 SmallVector<NamedAttribute> newAttrs;
173 for (NamedAttribute attr : curOp->getAttrs()) {
174 if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
175 Type newAttr = converter->convertType(typeAttr.getValue());
176 newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
177 } else {
178 newAttrs.push_back(attr);
179 }
180 }
181
182 rewriter.replaceOpWithNewOp<omp::MapInfoOp>(
183 curOp, resTypes, adaptor.getOperands(), newAttrs);
184 return success();
185 }
186};
187
188struct ReductionOpConversion : public ConvertOpToLLVMPattern<omp::ReductionOp> {
189 using ConvertOpToLLVMPattern<omp::ReductionOp>::ConvertOpToLLVMPattern;
190 LogicalResult
191 matchAndRewrite(omp::ReductionOp curOp, OpAdaptor adaptor,
192 ConversionPatternRewriter &rewriter) const override {
193 if (isa<MemRefType>(curOp.getAccumulator().getType())) {
194 // TODO: Support memref type in variable operands
195 return rewriter.notifyMatchFailure(curOp, "memref is not supported yet");
196 }
197 rewriter.replaceOpWithNewOp<omp::ReductionOp>(
198 curOp, TypeRange(), adaptor.getOperands(), curOp->getAttrs());
199 return success();
200 }
201};
202
203template <typename OpType>
204struct MultiRegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
205 using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern;
206
207 void forwardOpAttrs(OpType curOp, OpType newOp) const {}
208
209 LogicalResult
210 matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor,
211 ConversionPatternRewriter &rewriter) const override {
212 auto newOp = rewriter.create<OpType>(
213 curOp.getLoc(), TypeRange(), curOp.getSymNameAttr(),
214 TypeAttr::get(this->getTypeConverter()->convertType(
215 curOp.getTypeAttr().getValue())));
216 forwardOpAttrs(curOp, newOp: newOp);
217
218 for (unsigned idx = 0; idx < curOp.getNumRegions(); idx++) {
219 rewriter.inlineRegionBefore(curOp.getRegion(idx), newOp.getRegion(idx),
220 newOp.getRegion(idx).end());
221 if (failed(rewriter.convertRegionTypes(region: &newOp.getRegion(idx),
222 converter: *this->getTypeConverter())))
223 return failure();
224 }
225
226 rewriter.eraseOp(op: curOp);
227 return success();
228 }
229};
230
231template <>
232void MultiRegionOpConversion<omp::PrivateClauseOp>::forwardOpAttrs(
233 omp::PrivateClauseOp curOp, omp::PrivateClauseOp newOp) const {
234 newOp.setDataSharingType(curOp.getDataSharingType());
235}
236} // namespace
237
238void mlir::configureOpenMPToLLVMConversionLegality(
239 ConversionTarget &target, LLVMTypeConverter &typeConverter) {
240 target.addDynamicallyLegalOp<
241 mlir::omp::AtomicReadOp, mlir::omp::AtomicWriteOp, mlir::omp::FlushOp,
242 mlir::omp::ThreadprivateOp, mlir::omp::YieldOp,
243 mlir::omp::TargetEnterDataOp, mlir::omp::TargetExitDataOp,
244 mlir::omp::TargetUpdateOp, mlir::omp::MapBoundsOp, mlir::omp::MapInfoOp>(
245 [&](Operation *op) {
246 return typeConverter.isLegal(op->getOperandTypes()) &&
247 typeConverter.isLegal(op->getResultTypes());
248 });
249 target.addDynamicallyLegalOp<mlir::omp::ReductionOp>([&](Operation *op) {
250 return typeConverter.isLegal(op->getOperandTypes());
251 });
252 target.addDynamicallyLegalOp<
253 mlir::omp::AtomicUpdateOp, mlir::omp::CriticalOp, mlir::omp::TargetOp,
254 mlir::omp::TargetDataOp, mlir::omp::LoopNestOp,
255 mlir::omp::OrderedRegionOp, mlir::omp::ParallelOp, mlir::omp::WsloopOp,
256 mlir::omp::SimdOp, mlir::omp::MasterOp, mlir::omp::SectionOp,
257 mlir::omp::SectionsOp, mlir::omp::SingleOp, mlir::omp::TaskgroupOp,
258 mlir::omp::TaskOp, mlir::omp::DeclareReductionOp,
259 mlir::omp::PrivateClauseOp>([&](Operation *op) {
260 return std::all_of(op->getRegions().begin(), op->getRegions().end(),
261 [&](Region &region) {
262 return typeConverter.isLegal(&region);
263 }) &&
264 typeConverter.isLegal(op->getOperandTypes()) &&
265 typeConverter.isLegal(op->getResultTypes());
266 });
267}
268
269void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
270 RewritePatternSet &patterns) {
271 // This type is allowed when converting OpenMP to LLVM Dialect, it carries
272 // bounds information for map clauses and the operation and type are
273 // discarded on lowering to LLVM-IR from the OpenMP dialect.
274 converter.addConversion(
275 callback: [&](omp::MapBoundsType type) -> Type { return type; });
276
277 patterns.add<
278 AtomicReadOpConversion, MapInfoOpConversion, ReductionOpConversion,
279 MultiRegionOpConversion<omp::DeclareReductionOp>,
280 MultiRegionOpConversion<omp::PrivateClauseOp>,
281 RegionOpConversion<omp::CriticalOp>, RegionOpConversion<omp::LoopNestOp>,
282 RegionOpConversion<omp::MasterOp>, ReductionOpConversion,
283 RegionOpConversion<omp::OrderedRegionOp>,
284 RegionOpConversion<omp::ParallelOp>, RegionOpConversion<omp::WsloopOp>,
285 RegionOpConversion<omp::SectionsOp>, RegionOpConversion<omp::SectionOp>,
286 RegionOpConversion<omp::SimdOp>, RegionOpConversion<omp::SingleOp>,
287 RegionOpConversion<omp::TaskgroupOp>, RegionOpConversion<omp::TaskOp>,
288 RegionOpConversion<omp::TargetDataOp>, RegionOpConversion<omp::TargetOp>,
289 RegionLessOpWithVarOperandsConversion<omp::AtomicWriteOp>,
290 RegionOpWithVarOperandsConversion<omp::AtomicUpdateOp>,
291 RegionLessOpWithVarOperandsConversion<omp::FlushOp>,
292 RegionLessOpWithVarOperandsConversion<omp::ThreadprivateOp>,
293 RegionLessOpConversion<omp::YieldOp>,
294 RegionLessOpConversion<omp::TargetEnterDataOp>,
295 RegionLessOpConversion<omp::TargetExitDataOp>,
296 RegionLessOpConversion<omp::TargetUpdateOp>,
297 RegionLessOpWithVarOperandsConversion<omp::MapBoundsOp>>(converter);
298}
299
300namespace {
301struct ConvertOpenMPToLLVMPass
302 : public impl::ConvertOpenMPToLLVMPassBase<ConvertOpenMPToLLVMPass> {
303 using Base::Base;
304
305 void runOnOperation() override;
306};
307} // namespace
308
309void ConvertOpenMPToLLVMPass::runOnOperation() {
310 auto module = getOperation();
311
312 // Convert to OpenMP operations with LLVM IR dialect
313 RewritePatternSet patterns(&getContext());
314 LLVMTypeConverter converter(&getContext());
315 arith::populateArithToLLVMConversionPatterns(converter, patterns);
316 cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
317 populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
318 populateFuncToLLVMConversionPatterns(converter, patterns);
319 populateOpenMPToLLVMConversionPatterns(converter, patterns);
320
321 LLVMConversionTarget target(getContext());
322 target.addLegalOp<omp::TerminatorOp, omp::TaskyieldOp, omp::FlushOp,
323 omp::BarrierOp, omp::TaskwaitOp>();
324 configureOpenMPToLLVMConversionLegality(target, typeConverter&: converter);
325 if (failed(applyPartialConversion(module, target, std::move(patterns))))
326 signalPassFailure();
327}
328

source code of mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp