| 1 | //===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
// This file implements patterns and a pass to convert the MLIR ControlFlow
// (`cf`) dialect into the LLVM IR dialect.
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" |
| 15 | |
| 16 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
| 17 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
| 18 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
| 19 | #include "mlir/Conversion/LLVMCommon/PrintCallHelper.h" |
| 20 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
| 21 | #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" |
| 22 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| 23 | #include "mlir/IR/BuiltinOps.h" |
| 24 | #include "mlir/IR/PatternMatch.h" |
| 25 | #include "mlir/Pass/Pass.h" |
| 26 | #include "mlir/Transforms/DialectConversion.h" |
| 27 | |
| 28 | namespace mlir { |
| 29 | #define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVMPASS |
| 30 | #include "mlir/Conversion/Passes.h.inc" |
| 31 | } // namespace mlir |
| 32 | |
| 33 | using namespace mlir; |
| 34 | |
| 35 | #define PASS_NAME "convert-cf-to-llvm" |
| 36 | |
| 37 | namespace { |
| 38 | /// Lower `cf.assert`. The default lowering calls the `abort` function if the |
| 39 | /// assertion is violated and has no effect otherwise. The failure message is |
| 40 | /// ignored by the default lowering but should be propagated by any custom |
| 41 | /// lowering. |
| 42 | struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> { |
| 43 | explicit AssertOpLowering(const LLVMTypeConverter &typeConverter, |
| 44 | bool abortOnFailedAssert = true, |
| 45 | SymbolTableCollection *symbolTables = nullptr) |
| 46 | : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1), |
| 47 | abortOnFailedAssert(abortOnFailedAssert), symbolTables(symbolTables) {} |
| 48 | |
| 49 | LogicalResult |
| 50 | matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, |
| 51 | ConversionPatternRewriter &rewriter) const override { |
| 52 | auto loc = op.getLoc(); |
| 53 | auto module = op->getParentOfType<ModuleOp>(); |
| 54 | |
| 55 | // Split block at `assert` operation. |
| 56 | Block *opBlock = rewriter.getInsertionBlock(); |
| 57 | auto opPosition = rewriter.getInsertionPoint(); |
| 58 | Block *continuationBlock = rewriter.splitBlock(block: opBlock, before: opPosition); |
| 59 | |
| 60 | // Failed block: Generate IR to print the message and call `abort`. |
| 61 | Block *failureBlock = rewriter.createBlock(parent: opBlock->getParent()); |
| 62 | auto createResult = LLVM::createPrintStrCall( |
| 63 | builder&: rewriter, loc, moduleOp: module, symbolName: "assert_msg" , string: op.getMsg(), typeConverter: *getTypeConverter(), |
| 64 | /*addNewLine=*/addNewline: false, |
| 65 | /*runtimeFunctionName=*/"puts" , symbolTables); |
| 66 | if (createResult.failed()) |
| 67 | return failure(); |
| 68 | |
| 69 | if (abortOnFailedAssert) { |
| 70 | // Insert the `abort` declaration if necessary. |
| 71 | auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(name: "abort" ); |
| 72 | if (!abortFunc) { |
| 73 | OpBuilder::InsertionGuard guard(rewriter); |
| 74 | rewriter.setInsertionPointToStart(module.getBody()); |
| 75 | auto abortFuncTy = LLVM::LLVMFunctionType::get(result: getVoidType(), arguments: {}); |
| 76 | abortFunc = rewriter.create<LLVM::LLVMFuncOp>(location: rewriter.getUnknownLoc(), |
| 77 | args: "abort" , args&: abortFuncTy); |
| 78 | } |
| 79 | rewriter.create<LLVM::CallOp>(location: loc, args&: abortFunc, args: ValueRange()); |
| 80 | rewriter.create<LLVM::UnreachableOp>(location: loc); |
| 81 | } else { |
| 82 | rewriter.create<LLVM::BrOp>(location: loc, args: ValueRange(), args&: continuationBlock); |
| 83 | } |
| 84 | |
| 85 | // Generate assertion test. |
| 86 | rewriter.setInsertionPointToEnd(opBlock); |
| 87 | rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( |
| 88 | op, args: adaptor.getArg(), args&: continuationBlock, args&: failureBlock); |
| 89 | |
| 90 | return success(); |
| 91 | } |
| 92 | |
| 93 | private: |
| 94 | /// If set to `false`, messages are printed but program execution continues. |
| 95 | /// This is useful for testing asserts. |
| 96 | bool abortOnFailedAssert = true; |
| 97 | |
| 98 | SymbolTableCollection *symbolTables = nullptr; |
| 99 | }; |
| 100 | |
| 101 | /// Helper function for converting branch ops. This function converts the |
| 102 | /// signature of the given block. If the new block signature is different from |
| 103 | /// `expectedTypes`, returns "failure". |
| 104 | static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter, |
| 105 | const TypeConverter *converter, |
| 106 | Operation *branchOp, Block *block, |
| 107 | TypeRange expectedTypes) { |
| 108 | assert(converter && "expected non-null type converter" ); |
| 109 | assert(!block->isEntryBlock() && "entry blocks have no predecessors" ); |
| 110 | |
| 111 | // There is nothing to do if the types already match. |
| 112 | if (block->getArgumentTypes() == expectedTypes) |
| 113 | return block; |
| 114 | |
| 115 | // Compute the new block argument types and convert the block. |
| 116 | std::optional<TypeConverter::SignatureConversion> conversion = |
| 117 | converter->convertBlockSignature(block); |
| 118 | if (!conversion) |
| 119 | return rewriter.notifyMatchFailure(arg&: branchOp, |
| 120 | msg: "could not compute block signature" ); |
| 121 | if (expectedTypes != conversion->getConvertedTypes()) |
| 122 | return rewriter.notifyMatchFailure( |
| 123 | arg&: branchOp, |
| 124 | msg: "mismatch between adaptor operand types and computed block signature" ); |
| 125 | return rewriter.applySignatureConversion(block, conversion&: *conversion, converter); |
| 126 | } |
| 127 | |
| 128 | /// Convert the destination block signature (if necessary) and lower the branch |
| 129 | /// op to llvm.br. |
| 130 | struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> { |
| 131 | using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern; |
| 132 | |
| 133 | LogicalResult |
| 134 | matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor, |
| 135 | ConversionPatternRewriter &rewriter) const override { |
| 136 | FailureOr<Block *> convertedBlock = |
| 137 | getConvertedBlock(rewriter, converter: getTypeConverter(), branchOp: op, block: op.getSuccessor(), |
| 138 | expectedTypes: TypeRange(adaptor.getOperands())); |
| 139 | if (failed(Result: convertedBlock)) |
| 140 | return failure(); |
| 141 | DictionaryAttr attrs = op->getAttrDictionary(); |
| 142 | Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>( |
| 143 | op, args: adaptor.getOperands(), args&: *convertedBlock); |
| 144 | // TODO: We should not just forward all attributes like that. But there are |
| 145 | // existing Flang tests that depend on this behavior. |
| 146 | newOp->setAttrs(attrs); |
| 147 | return success(); |
| 148 | } |
| 149 | }; |
| 150 | |
| 151 | /// Convert the destination block signatures (if necessary) and lower the |
| 152 | /// branch op to llvm.cond_br. |
| 153 | struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> { |
| 154 | using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern; |
| 155 | |
| 156 | LogicalResult |
| 157 | matchAndRewrite(cf::CondBranchOp op, |
| 158 | typename cf::CondBranchOp::Adaptor adaptor, |
| 159 | ConversionPatternRewriter &rewriter) const override { |
| 160 | FailureOr<Block *> convertedTrueBlock = |
| 161 | getConvertedBlock(rewriter, converter: getTypeConverter(), branchOp: op, block: op.getTrueDest(), |
| 162 | expectedTypes: TypeRange(adaptor.getTrueDestOperands())); |
| 163 | if (failed(Result: convertedTrueBlock)) |
| 164 | return failure(); |
| 165 | FailureOr<Block *> convertedFalseBlock = |
| 166 | getConvertedBlock(rewriter, converter: getTypeConverter(), branchOp: op, block: op.getFalseDest(), |
| 167 | expectedTypes: TypeRange(adaptor.getFalseDestOperands())); |
| 168 | if (failed(Result: convertedFalseBlock)) |
| 169 | return failure(); |
| 170 | DictionaryAttr attrs = op->getAttrDictionary(); |
| 171 | auto newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( |
| 172 | op, args: adaptor.getCondition(), args: adaptor.getTrueDestOperands(), |
| 173 | args: adaptor.getFalseDestOperands(), args: op.getBranchWeightsAttr(), |
| 174 | args&: *convertedTrueBlock, args&: *convertedFalseBlock); |
| 175 | // TODO: We should not just forward all attributes like that. But there are |
| 176 | // existing Flang tests that depend on this behavior. |
| 177 | newOp->setAttrs(attrs); |
| 178 | return success(); |
| 179 | } |
| 180 | }; |
| 181 | |
| 182 | /// Convert the destination block signatures (if necessary) and lower the |
| 183 | /// switch op to llvm.switch. |
| 184 | struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> { |
| 185 | using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern; |
| 186 | |
| 187 | LogicalResult |
| 188 | matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor, |
| 189 | ConversionPatternRewriter &rewriter) const override { |
| 190 | // Get or convert default block. |
| 191 | FailureOr<Block *> convertedDefaultBlock = getConvertedBlock( |
| 192 | rewriter, converter: getTypeConverter(), branchOp: op, block: op.getDefaultDestination(), |
| 193 | expectedTypes: TypeRange(adaptor.getDefaultOperands())); |
| 194 | if (failed(Result: convertedDefaultBlock)) |
| 195 | return failure(); |
| 196 | |
| 197 | // Get or convert all case blocks. |
| 198 | SmallVector<Block *> caseDestinations; |
| 199 | SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands(); |
| 200 | for (auto it : llvm::enumerate(First: op.getCaseDestinations())) { |
| 201 | Block *b = it.value(); |
| 202 | FailureOr<Block *> convertedBlock = |
| 203 | getConvertedBlock(rewriter, converter: getTypeConverter(), branchOp: op, block: b, |
| 204 | expectedTypes: TypeRange(caseOperands[it.index()])); |
| 205 | if (failed(Result: convertedBlock)) |
| 206 | return failure(); |
| 207 | caseDestinations.push_back(Elt: *convertedBlock); |
| 208 | } |
| 209 | |
| 210 | rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( |
| 211 | op, args: adaptor.getFlag(), args&: *convertedDefaultBlock, |
| 212 | args: adaptor.getDefaultOperands(), args: adaptor.getCaseValuesAttr(), |
| 213 | args&: caseDestinations, args&: caseOperands); |
| 214 | return success(); |
| 215 | } |
| 216 | }; |
| 217 | |
| 218 | } // namespace |
| 219 | |
| 220 | void mlir::cf::populateControlFlowToLLVMConversionPatterns( |
| 221 | const LLVMTypeConverter &converter, RewritePatternSet &patterns) { |
| 222 | // clang-format off |
| 223 | patterns.add< |
| 224 | BranchOpLowering, |
| 225 | CondBranchOpLowering, |
| 226 | SwitchOpLowering>(arg: converter); |
| 227 | // clang-format on |
| 228 | } |
| 229 | |
| 230 | void mlir::cf::populateAssertToLLVMConversionPattern( |
| 231 | const LLVMTypeConverter &converter, RewritePatternSet &patterns, |
| 232 | bool abortOnFailure, SymbolTableCollection *symbolTables) { |
| 233 | patterns.add<AssertOpLowering>(arg: converter, args&: abortOnFailure, args&: symbolTables); |
| 234 | } |
| 235 | |
| 236 | //===----------------------------------------------------------------------===// |
| 237 | // Pass Definition |
| 238 | //===----------------------------------------------------------------------===// |
| 239 | |
| 240 | namespace { |
| 241 | /// A pass converting MLIR operations into the LLVM IR dialect. |
| 242 | struct ConvertControlFlowToLLVM |
| 243 | : public impl::ConvertControlFlowToLLVMPassBase<ConvertControlFlowToLLVM> { |
| 244 | |
| 245 | using Base::Base; |
| 246 | |
| 247 | /// Run the dialect converter on the module. |
| 248 | void runOnOperation() override { |
| 249 | MLIRContext *ctx = &getContext(); |
| 250 | LLVMConversionTarget target(*ctx); |
| 251 | // This pass lowers only CF dialect ops, but it also modifies block |
| 252 | // signatures inside other ops. These ops should be treated as legal. They |
| 253 | // are lowered by other passes. |
| 254 | target.markUnknownOpDynamicallyLegal(fn: [&](Operation *op) { |
| 255 | return op->getDialect() != |
| 256 | ctx->getLoadedDialect<cf::ControlFlowDialect>(); |
| 257 | }); |
| 258 | |
| 259 | LowerToLLVMOptions options(ctx); |
| 260 | if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) |
| 261 | options.overrideIndexBitwidth(bitwidth: indexBitwidth); |
| 262 | |
| 263 | LLVMTypeConverter converter(ctx, options); |
| 264 | RewritePatternSet patterns(ctx); |
| 265 | mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); |
| 266 | mlir::cf::populateAssertToLLVMConversionPattern(converter, patterns); |
| 267 | |
| 268 | if (failed(Result: applyPartialConversion(op: getOperation(), target, |
| 269 | patterns: std::move(patterns)))) |
| 270 | signalPassFailure(); |
| 271 | } |
| 272 | }; |
| 273 | } // namespace |
| 274 | |
| 275 | //===----------------------------------------------------------------------===// |
| 276 | // ConvertToLLVMPatternInterface implementation |
| 277 | //===----------------------------------------------------------------------===// |
| 278 | |
| 279 | namespace { |
| 280 | /// Implement the interface to convert MemRef to LLVM. |
| 281 | struct ControlFlowToLLVMDialectInterface |
| 282 | : public ConvertToLLVMPatternInterface { |
| 283 | using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; |
| 284 | void loadDependentDialects(MLIRContext *context) const final { |
| 285 | context->loadDialect<LLVM::LLVMDialect>(); |
| 286 | } |
| 287 | |
| 288 | /// Hook for derived dialect interface to provide conversion patterns |
| 289 | /// and mark dialect legal for the conversion target. |
| 290 | void populateConvertToLLVMConversionPatterns( |
| 291 | ConversionTarget &target, LLVMTypeConverter &typeConverter, |
| 292 | RewritePatternSet &patterns) const final { |
| 293 | mlir::cf::populateControlFlowToLLVMConversionPatterns(converter: typeConverter, |
| 294 | patterns); |
| 295 | mlir::cf::populateAssertToLLVMConversionPattern(converter: typeConverter, patterns); |
| 296 | } |
| 297 | }; |
| 298 | } // namespace |
| 299 | |
| 300 | void mlir::cf::registerConvertControlFlowToLLVMInterface( |
| 301 | DialectRegistry ®istry) { |
| 302 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) { |
| 303 | dialect->addInterfaces<ControlFlowToLLVMDialectInterface>(); |
| 304 | }); |
| 305 | } |
| 306 | |