1//===- ControlFlowToSPIRV.cpp - ControlFlow to SPIR-V Patterns ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements patterns to convert the ControlFlow dialect to the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
14#include "../SPIRVCommon/Pattern.h"
15#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
19#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
20#include "mlir/IR/AffineMap.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/Transforms/DialectConversion.h"
23#include "llvm/Support/Debug.h"
24#include "llvm/Support/FormatVariadic.h"
25
26#define DEBUG_TYPE "cf-to-spirv-pattern"
27
28using namespace mlir;
29
30/// Legailze target block arguments.
31static LogicalResult legalizeBlockArguments(Block &block, Operation *op,
32 PatternRewriter &rewriter,
33 const TypeConverter &converter) {
34 auto builder = OpBuilder::atBlockBegin(block: &block);
35 for (unsigned i = 0; i < block.getNumArguments(); ++i) {
36 BlockArgument arg = block.getArgument(i);
37 if (converter.isLegal(type: arg.getType()))
38 continue;
39 Type ty = arg.getType();
40 Type newTy = converter.convertType(t: ty);
41 if (!newTy) {
42 return rewriter.notifyMatchFailure(
43 arg&: op, msg: llvm::formatv(Fmt: "failed to legalize type for argument {0})", Vals&: arg));
44 }
45 unsigned argNum = arg.getArgNumber();
46 Location loc = arg.getLoc();
47 Value newArg = block.insertArgument(index: argNum, type: newTy, loc);
48 Value convertedValue = converter.materializeSourceConversion(
49 builder, loc: op->getLoc(), resultType: ty, inputs: newArg);
50 if (!convertedValue) {
51 return rewriter.notifyMatchFailure(
52 arg&: op, msg: llvm::formatv(Fmt: "failed to cast new argument {0} to type {1})",
53 Vals&: newArg, Vals&: ty));
54 }
55 arg.replaceAllUsesWith(newValue: convertedValue);
56 block.eraseArgument(index: argNum + 1);
57 }
58 return success();
59}
60
61//===----------------------------------------------------------------------===//
62// Operation conversion
63//===----------------------------------------------------------------------===//
64
65namespace {
66/// Converts cf.br to spirv.Branch.
67struct BranchOpPattern final : OpConversionPattern<cf::BranchOp> {
68 using OpConversionPattern::OpConversionPattern;
69
70 LogicalResult
71 matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor,
72 ConversionPatternRewriter &rewriter) const override {
73 if (failed(legalizeBlockArguments(*op.getDest(), op, rewriter,
74 *getTypeConverter())))
75 return failure();
76
77 rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(),
78 adaptor.getDestOperands());
79 return success();
80 }
81};
82
83/// Converts cf.cond_br to spirv.BranchConditional.
84struct CondBranchOpPattern final : OpConversionPattern<cf::CondBranchOp> {
85 using OpConversionPattern::OpConversionPattern;
86
87 LogicalResult
88 matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor,
89 ConversionPatternRewriter &rewriter) const override {
90 if (failed(legalizeBlockArguments(*op.getTrueDest(), op, rewriter,
91 *getTypeConverter())))
92 return failure();
93
94 if (failed(legalizeBlockArguments(*op.getFalseDest(), op, rewriter,
95 *getTypeConverter())))
96 return failure();
97
98 rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
99 op, adaptor.getCondition(), op.getTrueDest(),
100 adaptor.getTrueDestOperands(), op.getFalseDest(),
101 adaptor.getFalseDestOperands());
102 return success();
103 }
104};
105} // namespace
106
107//===----------------------------------------------------------------------===//
108// Pattern population
109//===----------------------------------------------------------------------===//
110
111void mlir::cf::populateControlFlowToSPIRVPatterns(
112 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
113 MLIRContext *context = patterns.getContext();
114
115 patterns.add<BranchOpPattern, CondBranchOpPattern>(arg: typeConverter, args&: context);
116}
117

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp