1//===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
10
11#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
12#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
13#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
14#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
15#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
16#include "mlir/Conversion/LLVMCommon/Pattern.h"
17#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
18#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
20#include "mlir/Pass/Pass.h"
21
22namespace mlir {
23#define GEN_PASS_DEF_CONVERTOPENMPTOLLVMPASS
24#include "mlir/Conversion/Passes.h.inc"
25} // namespace mlir
26
27using namespace mlir;
28
29namespace {
30
31/// A pattern that converts the result and operand types, attributes, and region
32/// arguments of an OpenMP operation to the LLVM dialect.
33///
34/// Attributes are copied verbatim by default, and only translated if they are
35/// type attributes.
36///
37/// Region bodies, if any, are not modified and expected to either be processed
38/// by the conversion infrastructure or already contain ops compatible with LLVM
39/// dialect types.
40template <typename T>
41struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> {
42 using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
43
44 OpenMPOpConversion(LLVMTypeConverter &typeConverter,
45 PatternBenefit benefit = 1)
46 : ConvertOpToLLVMPattern<T>(typeConverter, benefit) {
47 // Operations using CanonicalLoopInfoType are lowered only by
48 // mlir::translateModuleToLLVMIR() using the OpenMPIRBuilder. Until then,
49 // the type and operations using it must be preserved.
50 typeConverter.addConversion(
51 [&](::mlir::omp::CanonicalLoopInfoType type) { return type; });
52 }
53
54 LogicalResult
55 matchAndRewrite(T op, typename T::Adaptor adaptor,
56 ConversionPatternRewriter &rewriter) const override {
57 // Translate result types.
58 const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
59 SmallVector<Type> resTypes;
60 if (failed(converter->convertTypes(types: op->getResultTypes(), results&: resTypes)))
61 return failure();
62
63 // Translate type attributes.
64 // They are kept unmodified except if they are type attributes.
65 SmallVector<NamedAttribute> convertedAttrs;
66 for (NamedAttribute attr : op->getAttrs()) {
67 if (auto typeAttr = dyn_cast<TypeAttr>(Val: attr.getValue())) {
68 Type convertedType = converter->convertType(t: typeAttr.getValue());
69 convertedAttrs.emplace_back(Args: attr.getName(),
70 Args: TypeAttr::get(type: convertedType));
71 } else {
72 convertedAttrs.push_back(Elt: attr);
73 }
74 }
75
76 // Translate operands.
77 SmallVector<Value> convertedOperands;
78 convertedOperands.reserve(N: op->getNumOperands());
79 for (auto [originalOperand, convertedOperand] :
80 llvm::zip_equal(op->getOperands(), adaptor.getOperands())) {
81 if (!originalOperand)
82 return failure();
83
84 // TODO: Revisit whether we need to trigger an error specifically for this
85 // set of operations. Consider removing this check or updating the list.
86 if constexpr (llvm::is_one_of<T, omp::AtomicUpdateOp, omp::AtomicWriteOp,
87 omp::FlushOp, omp::MapBoundsOp,
88 omp::ThreadprivateOp>::value) {
89 if (isa<MemRefType>(originalOperand.getType())) {
90 // TODO: Support memref type in variable operands
91 return rewriter.notifyMatchFailure(op, "memref is not supported yet");
92 }
93 }
94 convertedOperands.push_back(Elt: convertedOperand);
95 }
96
97 // Create new operation.
98 auto newOp = rewriter.create<T>(op.getLoc(), resTypes, convertedOperands,
99 convertedAttrs);
100
101 // Translate regions.
102 for (auto [originalRegion, convertedRegion] :
103 llvm::zip_equal(op->getRegions(), newOp->getRegions())) {
104 rewriter.inlineRegionBefore(originalRegion, convertedRegion,
105 convertedRegion.end());
106 if (failed(rewriter.convertRegionTypes(region: &convertedRegion,
107 converter: *this->getTypeConverter())))
108 return failure();
109 }
110
111 // Delete old operation and replace result uses with those of the new one.
112 rewriter.replaceOp(op, newOp->getResults());
113 return success();
114 }
115};
116
117} // namespace
118
119void mlir::configureOpenMPToLLVMConversionLegality(
120 ConversionTarget &target, const LLVMTypeConverter &typeConverter) {
121 target.addDynamicallyLegalOp<
122#define GET_OP_LIST
123#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
124 >(callback: [&](Operation *op) {
125 return typeConverter.isLegal(range: op->getOperandTypes()) &&
126 typeConverter.isLegal(range: op->getResultTypes()) &&
127 llvm::all_of(Range: op->getRegions(),
128 P: [&](Region &region) {
129 return typeConverter.isLegal(region: &region);
130 }) &&
131 llvm::all_of(Range: op->getAttrs(), P: [&](NamedAttribute attr) {
132 auto typeAttr = dyn_cast<TypeAttr>(Val: attr.getValue());
133 return !typeAttr || typeConverter.isLegal(type: typeAttr.getValue());
134 });
135 });
136}
137
138/// Add an `OpenMPOpConversion<T>` conversion pattern for each operation type
139/// passed as template argument.
140template <typename... Ts>
141static inline RewritePatternSet &
142addOpenMPOpConversions(LLVMTypeConverter &converter,
143 RewritePatternSet &patterns) {
144 return patterns.add<OpenMPOpConversion<Ts>...>(converter);
145}
146
147void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
148 RewritePatternSet &patterns) {
149 // This type is allowed when converting OpenMP to LLVM Dialect, it carries
150 // bounds information for map clauses and the operation and type are
151 // discarded on lowering to LLVM-IR from the OpenMP dialect.
152 converter.addConversion(
153 callback: [&](omp::MapBoundsType type) -> Type { return type; });
154
155 // Add conversions for all OpenMP operations.
156 addOpenMPOpConversions<
157#define GET_OP_LIST
158#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
159 >(converter, patterns);
160}
161
162namespace {
163struct ConvertOpenMPToLLVMPass
164 : public impl::ConvertOpenMPToLLVMPassBase<ConvertOpenMPToLLVMPass> {
165 using Base::Base;
166
167 void runOnOperation() override;
168};
169} // namespace
170
171void ConvertOpenMPToLLVMPass::runOnOperation() {
172 auto module = getOperation();
173
174 // Convert to OpenMP operations with LLVM IR dialect
175 RewritePatternSet patterns(&getContext());
176 LLVMTypeConverter converter(&getContext());
177 arith::populateArithToLLVMConversionPatterns(converter, patterns);
178 cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
179 cf::populateAssertToLLVMConversionPattern(converter, patterns);
180 populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
181 populateFuncToLLVMConversionPatterns(converter, patterns);
182 populateOpenMPToLLVMConversionPatterns(converter, patterns);
183
184 LLVMConversionTarget target(getContext());
185 target.addLegalOp<omp::BarrierOp, omp::FlushOp, omp::TaskwaitOp,
186 omp::TaskyieldOp, omp::TerminatorOp>();
187 configureOpenMPToLLVMConversionLegality(target, typeConverter: converter);
188 if (failed(Result: applyPartialConversion(op: module, target, patterns: std::move(patterns))))
189 signalPassFailure();
190}
191
192//===----------------------------------------------------------------------===//
193// ConvertToLLVMPatternInterface implementation
194//===----------------------------------------------------------------------===//
195namespace {
196/// Implement the interface to convert OpenMP to LLVM.
197struct OpenMPToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
198 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
199 void loadDependentDialects(MLIRContext *context) const final {
200 context->loadDialect<LLVM::LLVMDialect>();
201 }
202
203 /// Hook for derived dialect interface to provide conversion patterns
204 /// and mark dialect legal for the conversion target.
205 void populateConvertToLLVMConversionPatterns(
206 ConversionTarget &target, LLVMTypeConverter &typeConverter,
207 RewritePatternSet &patterns) const final {
208 configureOpenMPToLLVMConversionLegality(target, typeConverter);
209 populateOpenMPToLLVMConversionPatterns(converter&: typeConverter, patterns);
210 }
211};
212} // namespace
213
214void mlir::registerConvertOpenMPToLLVMInterface(DialectRegistry &registry) {
215 registry.addExtension(extensionFn: +[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
216 dialect->addInterfaces<OpenMPToLLVMDialectInterface>();
217 });
218}
219

source code of mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp