1//===- SparseTensorTransformOps.cpp - sparse tensor transform ops impl ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h"
10#include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
11#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
12
13using namespace mlir;
14using namespace mlir::sparse_tensor;
15
16//===----------------------------------------------------------------------===//
17// Transform op implementation
18//===----------------------------------------------------------------------===//
19
20DiagnosedSilenceableFailure transform::MatchSparseInOut::matchOperation(
21 mlir::Operation *current, mlir::transform::TransformResults &results,
22 mlir::transform::TransformState &state) {
23 bool hasSparseInOut = hasAnySparseOperandOrResult(current);
24 if (!hasSparseInOut) {
25 return emitSilenceableFailure(current->getLoc(),
26 "operation has no sparse input or output");
27 }
28 results.set(cast<OpResult>(getResult()), state.getPayloadOps(getTarget()));
29 return DiagnosedSilenceableFailure::success();
30}
31
32//===----------------------------------------------------------------------===//
33// Transform op registration
34//===----------------------------------------------------------------------===//
35
36namespace {
37class SparseTensorTransformDialectExtension
38 : public transform::TransformDialectExtension<
39 SparseTensorTransformDialectExtension> {
40public:
41 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
42 SparseTensorTransformDialectExtension)
43
44 SparseTensorTransformDialectExtension() {
45 declareGeneratedDialect<sparse_tensor::SparseTensorDialect>();
46 registerTransformOps<
47#define GET_OP_LIST
48#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp.inc"
49 >();
50 }
51};
52} // namespace
53
54#define GET_OP_CLASSES
55#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp.inc"
56
57void mlir::sparse_tensor::registerTransformDialectExtension(
58 DialectRegistry &registry) {
59 registry.addExtensions<SparseTensorTransformDialectExtension>();
60}
61

// Source: mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp