1//===- TestTransformsOps.cpp - Test Transforms ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines transform dialect operations for testing MLIR
10// transformations
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Affine/IR/AffineOps.h"
15#include "mlir/Dialect/Transform/IR/TransformDialect.h"
16#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
17#include "mlir/Dialect/Utils/StaticValueUtils.h"
18#include "mlir/IR/OpDefinition.h"
19#include "mlir/Transforms/RegionUtils.h"
20
21#define GET_OP_CLASSES
22#include "TestTransformsOps.h.inc"
23
24using namespace mlir;
25using namespace mlir::transform;
26
27#define GET_OP_CLASSES
28#include "TestTransformsOps.cpp.inc"
29
30DiagnosedSilenceableFailure
31transform::TestMoveOperandDeps::apply(TransformRewriter &rewriter,
32 TransformResults &TransformResults,
33 TransformState &state) {
34 Operation *op = *state.getPayloadOps(value: getOp()).begin();
35 Operation *moveBefore = *state.getPayloadOps(value: getInsertionPoint()).begin();
36 if (failed(Result: moveOperationDependencies(rewriter, op, insertionPoint: moveBefore))) {
37 auto listener = cast<ErrorCheckingTrackingListener>(Val: rewriter.getListener());
38 std::string errorMsg = listener->getLatestMatchFailureMessage();
39 (void)emitRemark(message: errorMsg);
40 }
41 return DiagnosedSilenceableFailure::success();
42}
43
44DiagnosedSilenceableFailure
45transform::TestMoveValueDefns::apply(TransformRewriter &rewriter,
46 TransformResults &TransformResults,
47 TransformState &state) {
48 SmallVector<Value> values;
49 for (auto tdValue : getValues()) {
50 values.push_back(Elt: *state.getPayloadValues(handleValue: tdValue).begin());
51 }
52 Operation *moveBefore = *state.getPayloadOps(value: getInsertionPoint()).begin();
53 if (failed(Result: moveValueDefinitions(rewriter, values, insertionPoint: moveBefore))) {
54 auto listener = cast<ErrorCheckingTrackingListener>(Val: rewriter.getListener());
55 std::string errorMsg = listener->getLatestMatchFailureMessage();
56 (void)emitRemark(message: errorMsg);
57 }
58 return DiagnosedSilenceableFailure::success();
59}
60
61//===----------------------------------------------------------------------===//
62// Test affine functionality.
63//===----------------------------------------------------------------------===//
64DiagnosedSilenceableFailure
65transform::TestMakeComposedFoldedAffineApply::applyToOne(
66 TransformRewriter &rewriter, affine::AffineApplyOp affineApplyOp,
67 ApplyToEachResultList &results, TransformState &state) {
68 Location loc = affineApplyOp.getLoc();
69 OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
70 b&: rewriter, loc, map: affineApplyOp.getAffineMap(),
71 operands: getAsOpFoldResult(values: affineApplyOp.getOperands()),
72 /*composeAffineMin=*/true);
73 Value result;
74 if (auto v = dyn_cast<Value>(Val&: ofr)) {
75 result = v;
76 } else {
77 result = rewriter.create<arith::ConstantIndexOp>(
78 location: loc, args: getConstantIntValue(ofr).value());
79 }
80 results.push_back(op: result.getDefiningOp());
81 rewriter.replaceOp(op: affineApplyOp, newValues: result);
82 return DiagnosedSilenceableFailure::success();
83}
84
85//===----------------------------------------------------------------------===//
86// Extension
87//===----------------------------------------------------------------------===//
namespace {

/// Transform dialect extension that registers the test transform ops of this
/// file. The op list comes from the generated TestTransformsOps.cpp.inc.
class TestTransformsDialectExtension
    : public transform::TransformDialectExtension<
          TestTransformsDialectExtension> {
public:
  // Gives the extension an explicit TypeID. NOTE(review): presumably the
  // dialect registry uses this to identify and deduplicate the extension;
  // confirm.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformsDialectExtension)

  // Inherit the base class constructors.
  using Base::Base;

  /// Registers every op enumerated by the GET_OP_LIST section of
  /// TestTransformsOps.cpp.inc as a transform op.
  void init() {
    registerTransformOps<
// The generated file expands to the comma-separated list of op classes,
// which becomes the template argument list.
#define GET_OP_LIST
#include "TestTransformsOps.cpp.inc"
        >();
  }
};
} // namespace
106
107namespace test {
108void registerTestTransformsTransformDialectExtension(
109 DialectRegistry &registry) {
110 registry.addExtensions<TestTransformsDialectExtension>();
111}
112} // namespace test
113

source code of mlir/test/lib/Transforms/TestTransformsOps.cpp