1//===- SubsetInsertionOpInterfaceImpl.cpp - Tensor subsets ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
10
11#include "mlir/Dialect/Linalg/IR/Linalg.h"
12#include "mlir/Interfaces/SubsetOpInterface.h"
13
14using namespace mlir;
15using namespace mlir::linalg;
16
17namespace {
18struct LinalgCopyOpSubsetOpInterface
19 : public SubsetOpInterface::ExternalModel<LinalgCopyOpSubsetOpInterface,
20 linalg::CopyOp> {
21 bool operatesOnEquivalentSubset(
22 Operation *op, SubsetOpInterface candidate,
23 function_ref<bool(Value, Value)> equivalenceFn) const {
24 // linalg.copy operates on the entire destination tensor.
25 if (auto otherCopyOp = dyn_cast<linalg::CopyOp>(candidate.getOperation()))
26 return equivalenceFn(cast<linalg::CopyOp>(op).getOutputs()[0],
27 otherCopyOp.getOutputs()[0]);
28 // In the absence of an analysis, "false" is a conservative way to implement
29 // this interface.
30 return false;
31 }
32
33 bool operatesOnDisjointSubset(
34 Operation *op, SubsetOpInterface candidate,
35 function_ref<bool(Value, Value)> equivalenceFn) const {
36 // In the absence of an analysis, "false" is a conservative way to implement
37 // this interface.
38 return false;
39 }
40};
41
42struct LinalgCopyOpInterface
43 : public SubsetInsertionOpInterface::ExternalModel<LinalgCopyOpInterface,
44 linalg::CopyOp> {
45 OpOperand &getSourceOperand(Operation *op) const {
46 auto copyOp = cast<CopyOp>(op);
47 assert(copyOp.getInputs().size() == 1 && "expected single input");
48 return copyOp.getInputsMutable()[0];
49 }
50
51 bool
52 isEquivalentSubset(Operation *op, Value candidate,
53 function_ref<bool(Value, Value)> equivalenceFn) const {
54 auto copyOp = cast<CopyOp>(op);
55 assert(copyOp.getOutputs().size() == 1 && "expected single output");
56 return equivalenceFn(candidate, copyOp.getOutputs()[0]);
57 }
58
59 Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
60 Location loc) const {
61 auto copyOp = cast<CopyOp>(op);
62 assert(copyOp.getOutputs().size() == 1 && "expected single output");
63 return copyOp.getOutputs()[0];
64 }
65
66 SmallVector<Value>
67 getValuesNeededToBuildSubsetExtraction(Operation *op) const {
68 auto copyOp = cast<CopyOp>(op);
69 assert(copyOp.getOutputs().size() == 1 && "expected single output");
70 return {copyOp.getOutputs()[0]};
71 }
72};
73} // namespace
74
75void mlir::linalg::registerSubsetOpInterfaceExternalModels(
76 DialectRegistry &registry) {
77 registry.addExtension(extensionFn: +[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
78 linalg::CopyOp::attachInterface<LinalgCopyOpSubsetOpInterface>(*ctx);
79 linalg::CopyOp::attachInterface<LinalgCopyOpInterface>(*ctx);
80 });
81}
82

// Source: mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp