//===- UBToSPIRV.cpp - UB to SPIR-V dialect conversion --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
10
11#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
12#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
13#include "mlir/Dialect/UB/IR/UBOps.h"
14#include "mlir/Pass/Pass.h"
15
16namespace mlir {
17#define GEN_PASS_DEF_UBTOSPIRVCONVERSIONPASS
18#include "mlir/Conversion/Passes.h.inc"
19} // namespace mlir
20
21using namespace mlir;
22
23namespace {
24
25struct PoisonOpLowering final : OpConversionPattern<ub::PoisonOp> {
26 using OpConversionPattern::OpConversionPattern;
27
28 LogicalResult
29 matchAndRewrite(ub::PoisonOp op, OpAdaptor,
30 ConversionPatternRewriter &rewriter) const override {
31 Type origType = op.getType();
32 if (!origType.isIntOrIndexOrFloat())
33 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
34 diag << "unsupported type " << origType;
35 });
36
37 Type resType = getTypeConverter()->convertType(origType);
38 if (!resType)
39 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
40 diag << "failed to convert result type " << origType;
41 });
42
43 rewriter.replaceOpWithNewOp<spirv::UndefOp>(op, resType);
44 return success();
45 }
46};
47
48} // namespace
49
//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
53
54namespace {
55struct UBToSPIRVConversionPass final
56 : impl::UBToSPIRVConversionPassBase<UBToSPIRVConversionPass> {
57 using Base::Base;
58
59 void runOnOperation() override {
60 Operation *op = getOperation();
61 spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
62 std::unique_ptr<SPIRVConversionTarget> target =
63 SPIRVConversionTarget::get(targetAttr);
64
65 SPIRVConversionOptions options;
66 SPIRVTypeConverter typeConverter(targetAttr, options);
67
68 RewritePatternSet patterns(&getContext());
69 ub::populateUBToSPIRVConversionPatterns(converter&: typeConverter, patterns);
70
71 if (failed(applyPartialConversion(op, *target, std::move(patterns))))
72 signalPassFailure();
73 }
74};
75} // namespace
76
//===----------------------------------------------------------------------===//
// Pattern Population
//===----------------------------------------------------------------------===//
80
81void mlir::ub::populateUBToSPIRVConversionPatterns(
82 SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
83 patterns.add<PoisonOpLowering>(arg&: converter, args: patterns.getContext());
84}
85

source code of mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp