1//===- ControlFlowToSPIRV.cpp - ControlFlow to SPIR-V Patterns ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements patterns to convert the ControlFlow dialect to the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
14#include "../SPIRVCommon/Pattern.h"
15#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
19#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
20#include "mlir/IR/AffineMap.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/Support/LogicalResult.h"
23#include "mlir/Transforms/DialectConversion.h"
24#include "llvm/Support/Debug.h"
25#include "llvm/Support/FormatVariadic.h"
26
27#define DEBUG_TYPE "cf-to-spirv-pattern"
28
29using namespace mlir;
30
31/// Legailze target block arguments.
32static LogicalResult legalizeBlockArguments(Block &block, Operation *op,
33 PatternRewriter &rewriter,
34 const TypeConverter &converter) {
35 auto builder = OpBuilder::atBlockBegin(block: &block);
36 for (unsigned i = 0; i < block.getNumArguments(); ++i) {
37 BlockArgument arg = block.getArgument(i);
38 if (converter.isLegal(type: arg.getType()))
39 continue;
40 Type ty = arg.getType();
41 Type newTy = converter.convertType(t: ty);
42 if (!newTy) {
43 return rewriter.notifyMatchFailure(
44 arg&: op, msg: llvm::formatv(Fmt: "failed to legalize type for argument {0})", Vals&: arg));
45 }
46 unsigned argNum = arg.getArgNumber();
47 Location loc = arg.getLoc();
48 Value newArg = block.insertArgument(index: argNum, type: newTy, loc);
49 Value convertedValue = converter.materializeSourceConversion(
50 builder, loc: op->getLoc(), resultType: ty, inputs: newArg);
51 if (!convertedValue) {
52 return rewriter.notifyMatchFailure(
53 arg&: op, msg: llvm::formatv(Fmt: "failed to cast new argument {0} to type {1})",
54 Vals&: newArg, Vals&: ty));
55 }
56 arg.replaceAllUsesWith(newValue: convertedValue);
57 block.eraseArgument(index: argNum + 1);
58 }
59 return success();
60}
61
62//===----------------------------------------------------------------------===//
63// Operation conversion
64//===----------------------------------------------------------------------===//
65
66namespace {
67/// Converts cf.br to spirv.Branch.
68struct BranchOpPattern final : OpConversionPattern<cf::BranchOp> {
69 using OpConversionPattern::OpConversionPattern;
70
71 LogicalResult
72 matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor,
73 ConversionPatternRewriter &rewriter) const override {
74 if (failed(legalizeBlockArguments(*op.getDest(), op, rewriter,
75 *getTypeConverter())))
76 return failure();
77
78 rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(),
79 adaptor.getDestOperands());
80 return success();
81 }
82};
83
84/// Converts cf.cond_br to spirv.BranchConditional.
85struct CondBranchOpPattern final : OpConversionPattern<cf::CondBranchOp> {
86 using OpConversionPattern::OpConversionPattern;
87
88 LogicalResult
89 matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor,
90 ConversionPatternRewriter &rewriter) const override {
91 if (failed(legalizeBlockArguments(*op.getTrueDest(), op, rewriter,
92 *getTypeConverter())))
93 return failure();
94
95 if (failed(legalizeBlockArguments(*op.getFalseDest(), op, rewriter,
96 *getTypeConverter())))
97 return failure();
98
99 rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
100 op, adaptor.getCondition(), op.getTrueDest(),
101 adaptor.getTrueDestOperands(), op.getFalseDest(),
102 adaptor.getFalseDestOperands());
103 return success();
104 }
105};
106} // namespace
107
108//===----------------------------------------------------------------------===//
109// Pattern population
110//===----------------------------------------------------------------------===//
111
112void mlir::cf::populateControlFlowToSPIRVPatterns(
113 SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
114 MLIRContext *context = patterns.getContext();
115
116 patterns.add<BranchOpPattern, CondBranchOpPattern>(arg&: typeConverter, args&: context);
117}
118

source code of mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp