| 1 | //===-- ControlFlowConverter.cpp ------------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "flang/Optimizer/Dialect/FIRDialect.h" |
| 10 | #include "flang/Optimizer/Dialect/FIROps.h" |
| 11 | #include "flang/Optimizer/Dialect/FIROpsSupport.h" |
| 12 | #include "flang/Optimizer/Dialect/Support/FIRContext.h" |
| 13 | #include "flang/Optimizer/Dialect/Support/KindMapping.h" |
| 14 | #include "flang/Optimizer/Support/InternalNames.h" |
| 15 | #include "flang/Optimizer/Support/TypeCode.h" |
| 16 | #include "flang/Optimizer/Transforms/Passes.h" |
| 17 | #include "flang/Runtime/derived-api.h" |
| 18 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 19 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
| 20 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 21 | #include "mlir/Pass/Pass.h" |
| 22 | #include "mlir/Transforms/DialectConversion.h" |
| 23 | #include "llvm/ADT/SmallSet.h" |
| 24 | #include "llvm/Support/CommandLine.h" |
| 25 | |
| 26 | namespace fir { |
| 27 | #define GEN_PASS_DEF_CFGCONVERSION |
| 28 | #include "flang/Optimizer/Transforms/Passes.h.inc" |
| 29 | } // namespace fir |
| 30 | |
| 31 | using namespace fir; |
| 32 | using namespace mlir; |
| 33 | |
| 34 | namespace { |
| 35 | |
| 36 | // Conversion of fir control ops to more primitive control-flow. |
| 37 | // |
| 38 | // FIR loops that cannot be converted to the affine dialect will remain as |
| 39 | // `fir.do_loop` operations. These can be converted to control-flow operations. |
| 40 | |
| 41 | /// Convert `fir.do_loop` to CFG |
| 42 | class CfgLoopConv : public mlir::OpRewritePattern<fir::DoLoopOp> { |
| 43 | public: |
| 44 | using OpRewritePattern::OpRewritePattern; |
| 45 | |
| 46 | CfgLoopConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce, bool setNSW) |
| 47 | : mlir::OpRewritePattern<fir::DoLoopOp>(ctx), |
| 48 | forceLoopToExecuteOnce(forceLoopToExecuteOnce), setNSW(setNSW) {} |
| 49 | |
| 50 | llvm::LogicalResult |
| 51 | matchAndRewrite(DoLoopOp loop, |
| 52 | mlir::PatternRewriter &rewriter) const override { |
| 53 | auto loc = loop.getLoc(); |
| 54 | mlir::arith::IntegerOverflowFlags flags{}; |
| 55 | if (setNSW) |
| 56 | flags = bitEnumSet(flags, mlir::arith::IntegerOverflowFlags::nsw); |
| 57 | auto iofAttr = mlir::arith::IntegerOverflowFlagsAttr::get( |
| 58 | rewriter.getContext(), flags); |
| 59 | |
| 60 | // Create the start and end blocks that will wrap the DoLoopOp with an |
| 61 | // initalizer and an end point |
| 62 | auto *initBlock = rewriter.getInsertionBlock(); |
| 63 | auto initPos = rewriter.getInsertionPoint(); |
| 64 | auto *endBlock = rewriter.splitBlock(initBlock, initPos); |
| 65 | |
| 66 | // Split the first DoLoopOp block in two parts. The part before will be the |
| 67 | // conditional block since it already has the induction variable and |
| 68 | // loop-carried values as arguments. |
| 69 | auto *conditionalBlock = &loop.getRegion().front(); |
| 70 | conditionalBlock->addArgument(rewriter.getIndexType(), loc); |
| 71 | auto *firstBlock = |
| 72 | rewriter.splitBlock(conditionalBlock, conditionalBlock->begin()); |
| 73 | auto *lastBlock = &loop.getRegion().back(); |
| 74 | |
| 75 | // Move the blocks from the DoLoopOp between initBlock and endBlock |
| 76 | rewriter.inlineRegionBefore(loop.getRegion(), endBlock); |
| 77 | |
| 78 | // Get loop values from the DoLoopOp |
| 79 | auto low = loop.getLowerBound(); |
| 80 | auto high = loop.getUpperBound(); |
| 81 | assert(low && high && "must be a Value" ); |
| 82 | auto step = loop.getStep(); |
| 83 | |
| 84 | // Initalization block |
| 85 | rewriter.setInsertionPointToEnd(initBlock); |
| 86 | auto diff = rewriter.create<mlir::arith::SubIOp>(loc, high, low); |
| 87 | auto distance = rewriter.create<mlir::arith::AddIOp>(loc, diff, step); |
| 88 | mlir::Value iters = |
| 89 | rewriter.create<mlir::arith::DivSIOp>(loc, distance, step); |
| 90 | |
| 91 | if (forceLoopToExecuteOnce) { |
| 92 | auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0); |
| 93 | auto cond = rewriter.create<mlir::arith::CmpIOp>( |
| 94 | loc, arith::CmpIPredicate::sle, iters, zero); |
| 95 | auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1); |
| 96 | iters = rewriter.create<mlir::arith::SelectOp>(loc, cond, one, iters); |
| 97 | } |
| 98 | |
| 99 | llvm::SmallVector<mlir::Value> loopOperands; |
| 100 | loopOperands.push_back(low); |
| 101 | auto operands = loop.getIterOperands(); |
| 102 | loopOperands.append(operands.begin(), operands.end()); |
| 103 | loopOperands.push_back(iters); |
| 104 | |
| 105 | rewriter.create<mlir::cf::BranchOp>(loc, conditionalBlock, loopOperands); |
| 106 | |
| 107 | // Last loop block |
| 108 | auto *terminator = lastBlock->getTerminator(); |
| 109 | rewriter.setInsertionPointToEnd(lastBlock); |
| 110 | auto iv = conditionalBlock->getArgument(0); |
| 111 | mlir::Value steppedIndex = |
| 112 | rewriter.create<mlir::arith::AddIOp>(loc, iv, step, iofAttr); |
| 113 | assert(steppedIndex && "must be a Value" ); |
| 114 | auto lastArg = conditionalBlock->getNumArguments() - 1; |
| 115 | auto itersLeft = conditionalBlock->getArgument(lastArg); |
| 116 | auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1); |
| 117 | mlir::Value itersMinusOne = |
| 118 | rewriter.create<mlir::arith::SubIOp>(loc, itersLeft, one); |
| 119 | |
| 120 | llvm::SmallVector<mlir::Value> loopCarried; |
| 121 | loopCarried.push_back(steppedIndex); |
| 122 | auto begin = loop.getFinalValue() ? std::next(terminator->operand_begin()) |
| 123 | : terminator->operand_begin(); |
| 124 | loopCarried.append(begin, terminator->operand_end()); |
| 125 | loopCarried.push_back(itersMinusOne); |
| 126 | auto backEdge = |
| 127 | rewriter.create<mlir::cf::BranchOp>(loc, conditionalBlock, loopCarried); |
| 128 | rewriter.eraseOp(terminator); |
| 129 | |
| 130 | // Copy loop annotations from the do loop to the loop back edge. |
| 131 | if (auto ann = loop.getLoopAnnotation()) |
| 132 | backEdge->setAttr("loop_annotation" , *ann); |
| 133 | |
| 134 | // Conditional block |
| 135 | rewriter.setInsertionPointToEnd(conditionalBlock); |
| 136 | auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0); |
| 137 | auto comparison = rewriter.create<mlir::arith::CmpIOp>( |
| 138 | loc, arith::CmpIPredicate::sgt, itersLeft, zero); |
| 139 | |
| 140 | rewriter.create<mlir::cf::CondBranchOp>( |
| 141 | loc, comparison, firstBlock, llvm::ArrayRef<mlir::Value>(), endBlock, |
| 142 | llvm::ArrayRef<mlir::Value>()); |
| 143 | |
| 144 | // The result of the loop operation is the values of the condition block |
| 145 | // arguments except the induction variable on the last iteration. |
| 146 | auto args = loop.getFinalValue() |
| 147 | ? conditionalBlock->getArguments() |
| 148 | : conditionalBlock->getArguments().drop_front(); |
| 149 | rewriter.replaceOp(loop, args.drop_back()); |
| 150 | return success(); |
| 151 | } |
| 152 | |
| 153 | private: |
| 154 | bool forceLoopToExecuteOnce; |
| 155 | bool setNSW; |
| 156 | }; |
| 157 | |
| 158 | /// Convert `fir.if` to control-flow |
| 159 | class CfgIfConv : public mlir::OpRewritePattern<fir::IfOp> { |
| 160 | public: |
| 161 | using OpRewritePattern::OpRewritePattern; |
| 162 | |
| 163 | CfgIfConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce, bool setNSW) |
| 164 | : mlir::OpRewritePattern<fir::IfOp>(ctx) {} |
| 165 | |
| 166 | llvm::LogicalResult |
| 167 | matchAndRewrite(IfOp ifOp, mlir::PatternRewriter &rewriter) const override { |
| 168 | auto loc = ifOp.getLoc(); |
| 169 | |
| 170 | // Split the block containing the 'fir.if' into two parts. The part before |
| 171 | // will contain the condition, the part after will be the continuation |
| 172 | // point. |
| 173 | auto *condBlock = rewriter.getInsertionBlock(); |
| 174 | auto opPosition = rewriter.getInsertionPoint(); |
| 175 | auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition); |
| 176 | mlir::Block *continueBlock; |
| 177 | if (ifOp.getNumResults() == 0) { |
| 178 | continueBlock = remainingOpsBlock; |
| 179 | } else { |
| 180 | continueBlock = rewriter.createBlock( |
| 181 | remainingOpsBlock, ifOp.getResultTypes(), |
| 182 | llvm::SmallVector<mlir::Location>(ifOp.getNumResults(), loc)); |
| 183 | rewriter.create<mlir::cf::BranchOp>(loc, remainingOpsBlock); |
| 184 | } |
| 185 | |
| 186 | // Move blocks from the "then" region to the region containing 'fir.if', |
| 187 | // place it before the continuation block, and branch to it. |
| 188 | auto &ifOpRegion = ifOp.getThenRegion(); |
| 189 | auto *ifOpBlock = &ifOpRegion.front(); |
| 190 | auto *ifOpTerminator = ifOpRegion.back().getTerminator(); |
| 191 | auto ifOpTerminatorOperands = ifOpTerminator->getOperands(); |
| 192 | rewriter.setInsertionPointToEnd(&ifOpRegion.back()); |
| 193 | rewriter.create<mlir::cf::BranchOp>(loc, continueBlock, |
| 194 | ifOpTerminatorOperands); |
| 195 | rewriter.eraseOp(ifOpTerminator); |
| 196 | rewriter.inlineRegionBefore(ifOpRegion, continueBlock); |
| 197 | |
| 198 | // Move blocks from the "else" region (if present) to the region containing |
| 199 | // 'fir.if', place it before the continuation block and branch to it. It |
| 200 | // will be placed after the "then" regions. |
| 201 | auto *otherwiseBlock = continueBlock; |
| 202 | auto &otherwiseRegion = ifOp.getElseRegion(); |
| 203 | if (!otherwiseRegion.empty()) { |
| 204 | otherwiseBlock = &otherwiseRegion.front(); |
| 205 | auto *otherwiseTerm = otherwiseRegion.back().getTerminator(); |
| 206 | auto otherwiseTermOperands = otherwiseTerm->getOperands(); |
| 207 | rewriter.setInsertionPointToEnd(&otherwiseRegion.back()); |
| 208 | rewriter.create<mlir::cf::BranchOp>(loc, continueBlock, |
| 209 | otherwiseTermOperands); |
| 210 | rewriter.eraseOp(otherwiseTerm); |
| 211 | rewriter.inlineRegionBefore(otherwiseRegion, continueBlock); |
| 212 | } |
| 213 | |
| 214 | rewriter.setInsertionPointToEnd(condBlock); |
| 215 | rewriter.create<mlir::cf::CondBranchOp>( |
| 216 | loc, ifOp.getCondition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(), |
| 217 | otherwiseBlock, llvm::ArrayRef<mlir::Value>()); |
| 218 | rewriter.replaceOp(ifOp, continueBlock->getArguments()); |
| 219 | return success(); |
| 220 | } |
| 221 | }; |
| 222 | |
| 223 | /// Convert `fir.iter_while` to control-flow. |
| 224 | class CfgIterWhileConv : public mlir::OpRewritePattern<fir::IterWhileOp> { |
| 225 | public: |
| 226 | using OpRewritePattern::OpRewritePattern; |
| 227 | |
| 228 | CfgIterWhileConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce, |
| 229 | bool setNSW) |
| 230 | : mlir::OpRewritePattern<fir::IterWhileOp>(ctx), setNSW(setNSW) {} |
| 231 | |
| 232 | llvm::LogicalResult |
| 233 | matchAndRewrite(fir::IterWhileOp whileOp, |
| 234 | mlir::PatternRewriter &rewriter) const override { |
| 235 | auto loc = whileOp.getLoc(); |
| 236 | mlir::arith::IntegerOverflowFlags flags{}; |
| 237 | if (setNSW) |
| 238 | flags = bitEnumSet(flags, mlir::arith::IntegerOverflowFlags::nsw); |
| 239 | auto iofAttr = mlir::arith::IntegerOverflowFlagsAttr::get( |
| 240 | rewriter.getContext(), flags); |
| 241 | |
| 242 | // Start by splitting the block containing the 'fir.do_loop' into two parts. |
| 243 | // The part before will get the init code, the part after will be the end |
| 244 | // point. |
| 245 | auto *initBlock = rewriter.getInsertionBlock(); |
| 246 | auto initPosition = rewriter.getInsertionPoint(); |
| 247 | auto *endBlock = rewriter.splitBlock(initBlock, initPosition); |
| 248 | |
| 249 | // Use the first block of the loop body as the condition block since it is |
| 250 | // the block that has the induction variable and loop-carried values as |
| 251 | // arguments. Split out all operations from the first block into a new |
| 252 | // block. Move all body blocks from the loop body region to the region |
| 253 | // containing the loop. |
| 254 | auto *conditionBlock = &whileOp.getRegion().front(); |
| 255 | auto *firstBodyBlock = |
| 256 | rewriter.splitBlock(conditionBlock, conditionBlock->begin()); |
| 257 | auto *lastBodyBlock = &whileOp.getRegion().back(); |
| 258 | rewriter.inlineRegionBefore(whileOp.getRegion(), endBlock); |
| 259 | auto iv = conditionBlock->getArgument(0); |
| 260 | auto iterateVar = conditionBlock->getArgument(1); |
| 261 | |
| 262 | // Append the induction variable stepping logic to the last body block and |
| 263 | // branch back to the condition block. Loop-carried values are taken from |
| 264 | // operands of the loop terminator. |
| 265 | auto *terminator = lastBodyBlock->getTerminator(); |
| 266 | rewriter.setInsertionPointToEnd(lastBodyBlock); |
| 267 | auto step = whileOp.getStep(); |
| 268 | mlir::Value stepped = |
| 269 | rewriter.create<mlir::arith::AddIOp>(loc, iv, step, iofAttr); |
| 270 | assert(stepped && "must be a Value" ); |
| 271 | |
| 272 | llvm::SmallVector<mlir::Value> loopCarried; |
| 273 | loopCarried.push_back(stepped); |
| 274 | auto begin = whileOp.getFinalValue() |
| 275 | ? std::next(terminator->operand_begin()) |
| 276 | : terminator->operand_begin(); |
| 277 | loopCarried.append(begin, terminator->operand_end()); |
| 278 | rewriter.create<mlir::cf::BranchOp>(loc, conditionBlock, loopCarried); |
| 279 | rewriter.eraseOp(terminator); |
| 280 | |
| 281 | // Compute loop bounds before branching to the condition. |
| 282 | rewriter.setInsertionPointToEnd(initBlock); |
| 283 | auto lowerBound = whileOp.getLowerBound(); |
| 284 | auto upperBound = whileOp.getUpperBound(); |
| 285 | assert(lowerBound && upperBound && "must be a Value" ); |
| 286 | |
| 287 | // The initial values of loop-carried values is obtained from the operands |
| 288 | // of the loop operation. |
| 289 | llvm::SmallVector<mlir::Value> destOperands; |
| 290 | destOperands.push_back(lowerBound); |
| 291 | auto iterOperands = whileOp.getIterOperands(); |
| 292 | destOperands.append(iterOperands.begin(), iterOperands.end()); |
| 293 | rewriter.create<mlir::cf::BranchOp>(loc, conditionBlock, destOperands); |
| 294 | |
| 295 | // With the body block done, we can fill in the condition block. |
| 296 | rewriter.setInsertionPointToEnd(conditionBlock); |
| 297 | // The comparison depends on the sign of the step value. We fully expect |
| 298 | // this expression to be folded by the optimizer or LLVM. This expression |
| 299 | // is written this way so that `step == 0` always returns `false`. |
| 300 | auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0); |
| 301 | auto compl0 = rewriter.create<mlir::arith::CmpIOp>( |
| 302 | loc, arith::CmpIPredicate::slt, zero, step); |
| 303 | auto compl1 = rewriter.create<mlir::arith::CmpIOp>( |
| 304 | loc, arith::CmpIPredicate::sle, iv, upperBound); |
| 305 | auto compl2 = rewriter.create<mlir::arith::CmpIOp>( |
| 306 | loc, arith::CmpIPredicate::slt, step, zero); |
| 307 | auto compl3 = rewriter.create<mlir::arith::CmpIOp>( |
| 308 | loc, arith::CmpIPredicate::sle, upperBound, iv); |
| 309 | auto cmp0 = rewriter.create<mlir::arith::AndIOp>(loc, compl0, compl1); |
| 310 | auto cmp1 = rewriter.create<mlir::arith::AndIOp>(loc, compl2, compl3); |
| 311 | auto cmp2 = rewriter.create<mlir::arith::OrIOp>(loc, cmp0, cmp1); |
| 312 | // Remember to AND in the early-exit bool. |
| 313 | auto comparison = |
| 314 | rewriter.create<mlir::arith::AndIOp>(loc, iterateVar, cmp2); |
| 315 | rewriter.create<mlir::cf::CondBranchOp>( |
| 316 | loc, comparison, firstBodyBlock, llvm::ArrayRef<mlir::Value>(), |
| 317 | endBlock, llvm::ArrayRef<mlir::Value>()); |
| 318 | // The result of the loop operation is the values of the condition block |
| 319 | // arguments except the induction variable on the last iteration. |
| 320 | auto args = whileOp.getFinalValue() |
| 321 | ? conditionBlock->getArguments() |
| 322 | : conditionBlock->getArguments().drop_front(); |
| 323 | rewriter.replaceOp(whileOp, args); |
| 324 | return success(); |
| 325 | } |
| 326 | |
| 327 | private: |
| 328 | bool setNSW; |
| 329 | }; |
| 330 | |
| 331 | /// Convert FIR structured control flow ops to CFG ops. |
| 332 | class CfgConversion : public fir::impl::CFGConversionBase<CfgConversion> { |
| 333 | public: |
| 334 | using CFGConversionBase<CfgConversion>::CFGConversionBase; |
| 335 | |
| 336 | void runOnOperation() override { |
| 337 | auto *context = &this->getContext(); |
| 338 | mlir::RewritePatternSet patterns(context); |
| 339 | fir::populateCfgConversionRewrites(patterns, this->forceLoopToExecuteOnce, |
| 340 | this->setNSW); |
| 341 | mlir::ConversionTarget target(*context); |
| 342 | target.addLegalDialect<mlir::affine::AffineDialect, |
| 343 | mlir::cf::ControlFlowDialect, FIROpsDialect, |
| 344 | mlir::func::FuncDialect>(); |
| 345 | |
| 346 | // apply the patterns |
| 347 | target.addIllegalOp<ResultOp, DoLoopOp, IfOp, IterWhileOp>(); |
| 348 | target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); |
| 349 | if (mlir::failed(mlir::applyPartialConversion(this->getOperation(), target, |
| 350 | std::move(patterns)))) { |
| 351 | mlir::emitError(mlir::UnknownLoc::get(context), |
| 352 | "error in converting to CFG\n" ); |
| 353 | this->signalPassFailure(); |
| 354 | } |
| 355 | } |
| 356 | }; |
| 357 | |
| 358 | } // namespace |
| 359 | |
| 360 | /// Expose conversion rewriters to other passes |
| 361 | void fir::populateCfgConversionRewrites(mlir::RewritePatternSet &patterns, |
| 362 | bool forceLoopToExecuteOnce, |
| 363 | bool setNSW) { |
| 364 | patterns.insert<CfgLoopConv, CfgIfConv, CfgIterWhileConv>( |
| 365 | patterns.getContext(), forceLoopToExecuteOnce, setNSW); |
| 366 | } |
| 367 | |