1//===- SparseTensorTransformOps.cpp - sparse tensor transform ops impl ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h"
10#include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
11#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
12
13using namespace mlir;
14using namespace mlir::sparse_tensor;
15
16//===----------------------------------------------------------------------===//
17// Transform op implementation
18//===----------------------------------------------------------------------===//
19
20DiagnosedSilenceableFailure transform::MatchSparseInOut::matchOperation(
21 mlir::Operation *current, mlir::transform::TransformResults &results,
22 mlir::transform::TransformState &state) {
23 bool hasSparseInOut = hasAnySparseOperandOrResult(current);
24 if (!hasSparseInOut) {
25 return emitSilenceableFailure(current->getLoc(),
26 "operation has no sparse input or output");
27 }
28 results.set(cast<OpResult>(getResult()), state.getPayloadOps(getTarget()));
29 return DiagnosedSilenceableFailure::success();
30}
31
32//===----------------------------------------------------------------------===//
33// Transform op registration
34//===----------------------------------------------------------------------===//
35
36namespace {
37class SparseTensorTransformDialectExtension
38 : public transform::TransformDialectExtension<
39 SparseTensorTransformDialectExtension> {
40public:
41 SparseTensorTransformDialectExtension() {
42 declareGeneratedDialect<sparse_tensor::SparseTensorDialect>();
43 registerTransformOps<
44#define GET_OP_LIST
45#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp.inc"
46 >();
47 }
48};
49} // namespace
50
51#define GET_OP_CLASSES
52#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp.inc"
53
54void mlir::sparse_tensor::registerTransformDialectExtension(
55 DialectRegistry &registry) {
56 registry.addExtensions<SparseTensorTransformDialectExtension>();
57}
58

// Source: mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp