1 | //===- ControlFlowToSPIRV.cpp - ControlFlow to SPIR-V Patterns ------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file implements patterns to convert the ControlFlow dialect to the SPIR-V dialect.
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h" |
14 | #include "../SPIRVCommon/Pattern.h" |
15 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
16 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
17 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
18 | #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" |
19 | #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h" |
20 | #include "mlir/IR/AffineMap.h" |
21 | #include "mlir/IR/PatternMatch.h" |
22 | #include "mlir/Support/LogicalResult.h" |
23 | #include "mlir/Transforms/DialectConversion.h" |
24 | #include "llvm/Support/Debug.h" |
25 | #include "llvm/Support/FormatVariadic.h" |
26 | |
27 | #define DEBUG_TYPE "cf-to-spirv-pattern" |
28 | |
29 | using namespace mlir; |
30 | |
31 | /// Legailze target block arguments. |
32 | static LogicalResult legalizeBlockArguments(Block &block, Operation *op, |
33 | PatternRewriter &rewriter, |
34 | const TypeConverter &converter) { |
35 | auto builder = OpBuilder::atBlockBegin(block: &block); |
36 | for (unsigned i = 0; i < block.getNumArguments(); ++i) { |
37 | BlockArgument arg = block.getArgument(i); |
38 | if (converter.isLegal(type: arg.getType())) |
39 | continue; |
40 | Type ty = arg.getType(); |
41 | Type newTy = converter.convertType(t: ty); |
42 | if (!newTy) { |
43 | return rewriter.notifyMatchFailure( |
44 | arg&: op, msg: llvm::formatv(Fmt: "failed to legalize type for argument {0})" , Vals&: arg)); |
45 | } |
46 | unsigned argNum = arg.getArgNumber(); |
47 | Location loc = arg.getLoc(); |
48 | Value newArg = block.insertArgument(index: argNum, type: newTy, loc); |
49 | Value convertedValue = converter.materializeSourceConversion( |
50 | builder, loc: op->getLoc(), resultType: ty, inputs: newArg); |
51 | if (!convertedValue) { |
52 | return rewriter.notifyMatchFailure( |
53 | arg&: op, msg: llvm::formatv(Fmt: "failed to cast new argument {0} to type {1})" , |
54 | Vals&: newArg, Vals&: ty)); |
55 | } |
56 | arg.replaceAllUsesWith(newValue: convertedValue); |
57 | block.eraseArgument(index: argNum + 1); |
58 | } |
59 | return success(); |
60 | } |
61 | |
62 | //===----------------------------------------------------------------------===// |
63 | // Operation conversion |
64 | //===----------------------------------------------------------------------===// |
65 | |
66 | namespace { |
67 | /// Converts cf.br to spirv.Branch. |
68 | struct BranchOpPattern final : OpConversionPattern<cf::BranchOp> { |
69 | using OpConversionPattern::OpConversionPattern; |
70 | |
71 | LogicalResult |
72 | matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor, |
73 | ConversionPatternRewriter &rewriter) const override { |
74 | if (failed(legalizeBlockArguments(*op.getDest(), op, rewriter, |
75 | *getTypeConverter()))) |
76 | return failure(); |
77 | |
78 | rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(), |
79 | adaptor.getDestOperands()); |
80 | return success(); |
81 | } |
82 | }; |
83 | |
84 | /// Converts cf.cond_br to spirv.BranchConditional. |
85 | struct CondBranchOpPattern final : OpConversionPattern<cf::CondBranchOp> { |
86 | using OpConversionPattern::OpConversionPattern; |
87 | |
88 | LogicalResult |
89 | matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor, |
90 | ConversionPatternRewriter &rewriter) const override { |
91 | if (failed(legalizeBlockArguments(*op.getTrueDest(), op, rewriter, |
92 | *getTypeConverter()))) |
93 | return failure(); |
94 | |
95 | if (failed(legalizeBlockArguments(*op.getFalseDest(), op, rewriter, |
96 | *getTypeConverter()))) |
97 | return failure(); |
98 | |
99 | rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>( |
100 | op, adaptor.getCondition(), op.getTrueDest(), |
101 | adaptor.getTrueDestOperands(), op.getFalseDest(), |
102 | adaptor.getFalseDestOperands()); |
103 | return success(); |
104 | } |
105 | }; |
106 | } // namespace |
107 | |
108 | //===----------------------------------------------------------------------===// |
109 | // Pattern population |
110 | //===----------------------------------------------------------------------===// |
111 | |
112 | void mlir::cf::populateControlFlowToSPIRVPatterns( |
113 | SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { |
114 | MLIRContext *context = patterns.getContext(); |
115 | |
116 | patterns.add<BranchOpPattern, CondBranchOpPattern>(arg&: typeConverter, args&: context); |
117 | } |
118 | |