1//===- ControlFlowToSPIRV.cpp - ControlFlow to SPIR-V Patterns ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements patterns to convert the ControlFlow dialect to the
// SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
14#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
15#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
17#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
18#include "mlir/IR/AffineMap.h"
19#include "mlir/IR/PatternMatch.h"
20#include "mlir/Transforms/DialectConversion.h"
21#include "llvm/Support/FormatVariadic.h"
22
23#define DEBUG_TYPE "cf-to-spirv-pattern"
24
25using namespace mlir;
26
27/// Legailze target block arguments.
28static LogicalResult legalizeBlockArguments(Block &block, Operation *op,
29 PatternRewriter &rewriter,
30 const TypeConverter &converter) {
31 auto builder = OpBuilder::atBlockBegin(block: &block);
32 for (unsigned i = 0; i < block.getNumArguments(); ++i) {
33 BlockArgument arg = block.getArgument(i);
34 if (converter.isLegal(type: arg.getType()))
35 continue;
36 Type ty = arg.getType();
37 Type newTy = converter.convertType(t: ty);
38 if (!newTy) {
39 return rewriter.notifyMatchFailure(
40 arg&: op, msg: llvm::formatv(Fmt: "failed to legalize type for argument {0})", Vals&: arg));
41 }
42 unsigned argNum = arg.getArgNumber();
43 Location loc = arg.getLoc();
44 Value newArg = block.insertArgument(index: argNum, type: newTy, loc);
45 Value convertedValue = converter.materializeSourceConversion(
46 builder, loc: op->getLoc(), resultType: ty, inputs: newArg);
47 if (!convertedValue) {
48 return rewriter.notifyMatchFailure(
49 arg&: op, msg: llvm::formatv(Fmt: "failed to cast new argument {0} to type {1})",
50 Vals&: newArg, Vals&: ty));
51 }
52 arg.replaceAllUsesWith(newValue: convertedValue);
53 block.eraseArgument(index: argNum + 1);
54 }
55 return success();
56}
57
58//===----------------------------------------------------------------------===//
59// Operation conversion
60//===----------------------------------------------------------------------===//
61
62namespace {
63/// Converts cf.br to spirv.Branch.
64struct BranchOpPattern final : OpConversionPattern<cf::BranchOp> {
65 using OpConversionPattern::OpConversionPattern;
66
67 LogicalResult
68 matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor,
69 ConversionPatternRewriter &rewriter) const override {
70 if (failed(Result: legalizeBlockArguments(block&: *op.getDest(), op, rewriter,
71 converter: *getTypeConverter())))
72 return failure();
73
74 rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, args: op.getDest(),
75 args: adaptor.getDestOperands());
76 return success();
77 }
78};
79
80/// Converts cf.cond_br to spirv.BranchConditional.
81struct CondBranchOpPattern final : OpConversionPattern<cf::CondBranchOp> {
82 using OpConversionPattern::OpConversionPattern;
83
84 LogicalResult
85 matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor,
86 ConversionPatternRewriter &rewriter) const override {
87 if (failed(Result: legalizeBlockArguments(block&: *op.getTrueDest(), op, rewriter,
88 converter: *getTypeConverter())))
89 return failure();
90
91 if (failed(Result: legalizeBlockArguments(block&: *op.getFalseDest(), op, rewriter,
92 converter: *getTypeConverter())))
93 return failure();
94
95 rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
96 op, args: adaptor.getCondition(), args: op.getTrueDest(),
97 args: adaptor.getTrueDestOperands(), args: op.getFalseDest(),
98 args: adaptor.getFalseDestOperands());
99 return success();
100 }
101};
102} // namespace
103
104//===----------------------------------------------------------------------===//
105// Pattern population
106//===----------------------------------------------------------------------===//
107
108void mlir::cf::populateControlFlowToSPIRVPatterns(
109 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
110 MLIRContext *context = patterns.getContext();
111
112 patterns.add<BranchOpPattern, CondBranchOpPattern>(arg: typeConverter, args&: context);
113}
114

source code of mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp