1//===- TestTransformsOps.cpp - Test Transforms ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines transform dialect operations for testing MLIR
10// transformations
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
15#include "mlir/Dialect/Transform/IR/TransformDialect.h"
16#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
17#include "mlir/Transforms/RegionUtils.h"
18
19#define GET_OP_CLASSES
20#include "TestTransformsOps.h.inc"
21
22using namespace mlir;
23using namespace mlir::transform;
24
25#define GET_OP_CLASSES
26#include "TestTransformsOps.cpp.inc"
27
28DiagnosedSilenceableFailure
29transform::TestMoveOperandDeps::apply(TransformRewriter &rewriter,
30 TransformResults &TransformResults,
31 TransformState &state) {
32 Operation *op = *state.getPayloadOps(getOp()).begin();
33 Operation *moveBefore = *state.getPayloadOps(getInsertionPoint()).begin();
34 if (failed(moveOperationDependencies(rewriter, op, moveBefore))) {
35 auto listener = cast<ErrorCheckingTrackingListener>(rewriter.getListener());
36 std::string errorMsg = listener->getLatestMatchFailureMessage();
37 (void)emitRemark(errorMsg);
38 }
39 return DiagnosedSilenceableFailure::success();
40}
41
42DiagnosedSilenceableFailure
43transform::TestMoveValueDefns::apply(TransformRewriter &rewriter,
44 TransformResults &TransformResults,
45 TransformState &state) {
46 SmallVector<Value> values;
47 for (auto tdValue : getValues()) {
48 values.push_back(*state.getPayloadValues(tdValue).begin());
49 }
50 Operation *moveBefore = *state.getPayloadOps(getInsertionPoint()).begin();
51 if (failed(moveValueDefinitions(rewriter, values, moveBefore))) {
52 auto listener = cast<ErrorCheckingTrackingListener>(rewriter.getListener());
53 std::string errorMsg = listener->getLatestMatchFailureMessage();
54 (void)emitRemark(errorMsg);
55 }
56 return DiagnosedSilenceableFailure::success();
57}
58
59namespace {
60
61class TestTransformsDialectExtension
62 : public transform::TransformDialectExtension<
63 TestTransformsDialectExtension> {
64public:
65 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformsDialectExtension)
66
67 using Base::Base;
68
69 void init() {
70 registerTransformOps<
71#define GET_OP_LIST
72#include "TestTransformsOps.cpp.inc"
73 >();
74 }
75};
76} // namespace
77
78namespace test {
79void registerTestTransformsTransformDialectExtension(
80 DialectRegistry &registry) {
81 registry.addExtensions<TestTransformsDialectExtension>();
82}
83} // namespace test
84

source code of mlir/test/lib/Transforms/TestTransformsOps.cpp