//===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
10
11#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
12#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
13#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
14#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
15#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
16#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
17#include "mlir/Conversion/LLVMCommon/Pattern.h"
18#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
19#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
21#include "mlir/Pass/Pass.h"
22
23namespace mlir {
24#define GEN_PASS_DEF_CONVERTOPENMPTOLLVMPASS
25#include "mlir/Conversion/Passes.h.inc"
26} // namespace mlir
27
28using namespace mlir;
29
30namespace {
31
32/// A pattern that converts the result and operand types, attributes, and region
33/// arguments of an OpenMP operation to the LLVM dialect.
34///
35/// Attributes are copied verbatim by default, and only translated if they are
36/// type attributes.
37///
38/// Region bodies, if any, are not modified and expected to either be processed
39/// by the conversion infrastructure or already contain ops compatible with LLVM
40/// dialect types.
41template <typename T>
42struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> {
43 using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
44
45 LogicalResult
46 matchAndRewrite(T op, typename T::Adaptor adaptor,
47 ConversionPatternRewriter &rewriter) const override {
48 // Translate result types.
49 const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
50 SmallVector<Type> resTypes;
51 if (failed(converter->convertTypes(types: op->getResultTypes(), results&: resTypes)))
52 return failure();
53
54 // Translate type attributes.
55 // They are kept unmodified except if they are type attributes.
56 SmallVector<NamedAttribute> convertedAttrs;
57 for (NamedAttribute attr : op->getAttrs()) {
58 if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
59 Type convertedType = converter->convertType(typeAttr.getValue());
60 convertedAttrs.emplace_back(attr.getName(),
61 TypeAttr::get(convertedType));
62 } else {
63 convertedAttrs.push_back(attr);
64 }
65 }
66
67 // Translate operands.
68 SmallVector<Value> convertedOperands;
69 convertedOperands.reserve(op->getNumOperands());
70 for (auto [originalOperand, convertedOperand] :
71 llvm::zip_equal(op->getOperands(), adaptor.getOperands())) {
72 if (!originalOperand)
73 return failure();
74
75 // TODO: Revisit whether we need to trigger an error specifically for this
76 // set of operations. Consider removing this check or updating the list.
77 if constexpr (llvm::is_one_of<T, omp::AtomicUpdateOp, omp::AtomicWriteOp,
78 omp::FlushOp, omp::MapBoundsOp,
79 omp::ThreadprivateOp>::value) {
80 if (isa<MemRefType>(originalOperand.getType())) {
81 // TODO: Support memref type in variable operands
82 return rewriter.notifyMatchFailure(op, "memref is not supported yet");
83 }
84 }
85 convertedOperands.push_back(convertedOperand);
86 }
87
88 // Create new operation.
89 auto newOp = rewriter.create<T>(op.getLoc(), resTypes, convertedOperands,
90 convertedAttrs);
91
92 // Translate regions.
93 for (auto [originalRegion, convertedRegion] :
94 llvm::zip_equal(op->getRegions(), newOp->getRegions())) {
95 rewriter.inlineRegionBefore(originalRegion, convertedRegion,
96 convertedRegion.end());
97 if (failed(rewriter.convertRegionTypes(&convertedRegion,
98 *this->getTypeConverter())))
99 return failure();
100 }
101
102 // Delete old operation and replace result uses with those of the new one.
103 rewriter.replaceOp(op, newOp->getResults());
104 return success();
105 }
106};
107
108} // namespace
109
110void mlir::configureOpenMPToLLVMConversionLegality(
111 ConversionTarget &target, const LLVMTypeConverter &typeConverter) {
112 target.addDynamicallyLegalOp<
113#define GET_OP_LIST
114#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
115 >([&](Operation *op) {
116 return typeConverter.isLegal(op->getOperandTypes()) &&
117 typeConverter.isLegal(op->getResultTypes()) &&
118 llvm::all_of(op->getRegions(),
119 [&](Region &region) {
120 return typeConverter.isLegal(&region);
121 }) &&
122 llvm::all_of(op->getAttrs(), [&](NamedAttribute attr) {
123 auto typeAttr = dyn_cast<TypeAttr>(attr.getValue());
124 return !typeAttr || typeConverter.isLegal(typeAttr.getValue());
125 });
126 });
127}
128
129/// Add an `OpenMPOpConversion<T>` conversion pattern for each operation type
130/// passed as template argument.
131template <typename... Ts>
132static inline RewritePatternSet &
133addOpenMPOpConversions(LLVMTypeConverter &converter,
134 RewritePatternSet &patterns) {
135 return patterns.add<OpenMPOpConversion<Ts>...>(converter);
136}
137
138void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
139 RewritePatternSet &patterns) {
140 // This type is allowed when converting OpenMP to LLVM Dialect, it carries
141 // bounds information for map clauses and the operation and type are
142 // discarded on lowering to LLVM-IR from the OpenMP dialect.
143 converter.addConversion(
144 callback: [&](omp::MapBoundsType type) -> Type { return type; });
145
146 // Add conversions for all OpenMP operations.
147 addOpenMPOpConversions<
148#define GET_OP_LIST
149#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
150 >(converter, patterns);
151}
152
153namespace {
154struct ConvertOpenMPToLLVMPass
155 : public impl::ConvertOpenMPToLLVMPassBase<ConvertOpenMPToLLVMPass> {
156 using Base::Base;
157
158 void runOnOperation() override;
159};
160} // namespace
161
162void ConvertOpenMPToLLVMPass::runOnOperation() {
163 auto module = getOperation();
164
165 // Convert to OpenMP operations with LLVM IR dialect
166 RewritePatternSet patterns(&getContext());
167 LLVMTypeConverter converter(&getContext());
168 arith::populateArithToLLVMConversionPatterns(converter, patterns);
169 cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
170 cf::populateAssertToLLVMConversionPattern(converter, patterns);
171 populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
172 populateFuncToLLVMConversionPatterns(converter, patterns);
173 populateOpenMPToLLVMConversionPatterns(converter, patterns);
174
175 LLVMConversionTarget target(getContext());
176 target.addLegalOp<omp::BarrierOp, omp::FlushOp, omp::TaskwaitOp,
177 omp::TaskyieldOp, omp::TerminatorOp>();
178 configureOpenMPToLLVMConversionLegality(target, typeConverter: converter);
179 if (failed(applyPartialConversion(module, target, std::move(patterns))))
180 signalPassFailure();
181}
182
//===----------------------------------------------------------------------===//
// ConvertToLLVMPatternInterface implementation
//===----------------------------------------------------------------------===//
186namespace {
187/// Implement the interface to convert OpenMP to LLVM.
188struct OpenMPToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
189 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
190 void loadDependentDialects(MLIRContext *context) const final {
191 context->loadDialect<LLVM::LLVMDialect>();
192 }
193
194 /// Hook for derived dialect interface to provide conversion patterns
195 /// and mark dialect legal for the conversion target.
196 void populateConvertToLLVMConversionPatterns(
197 ConversionTarget &target, LLVMTypeConverter &typeConverter,
198 RewritePatternSet &patterns) const final {
199 configureOpenMPToLLVMConversionLegality(target, typeConverter);
200 populateOpenMPToLLVMConversionPatterns(converter&: typeConverter, patterns);
201 }
202};
203} // namespace
204
205void mlir::registerConvertOpenMPToLLVMInterface(DialectRegistry &registry) {
206 registry.addExtension(extensionFn: +[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
207 dialect->addInterfaces<OpenMPToLLVMDialectInterface>();
208 });
209}
210

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp