1//===- SubsetOpInterface.cpp - Tensor Subsets -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Interfaces/SubsetOpInterface.h"
10#include "mlir/Interfaces/DestinationStyleOpInterface.h"
11#include "mlir/Interfaces/ValueBoundsOpInterface.h"
12
13#include "mlir/Interfaces/SubsetOpInterface.cpp.inc"
14
15using namespace mlir;
16
17OpOperand &detail::defaultGetDestinationOperand(Operation *op) {
18 auto dstOp = dyn_cast<DestinationStyleOpInterface>(op);
19 assert(dstOp && "getDestination must be implemented for non-DPS ops");
20 assert(
21 dstOp.getNumDpsInits() == 1 &&
22 "getDestination must be implemented for ops with 0 or more than 1 init");
23 return *dstOp.getDpsInitOperand(0);
24}
25
26OpResult detail::defaultGetUpdatedDestination(Operation *op) {
27 auto dstOp = dyn_cast<DestinationStyleOpInterface>(op);
28 assert(dstOp && "getUpdatedDestination must be implemented for non-DPS ops");
29 auto insertionOp = cast<SubsetInsertionOpInterface>(op);
30 return dstOp.getTiedOpResult(&insertionOp.getDestinationOperand());
31}
32
33bool detail::defaultIsEquivalentSubset(
34 Operation *op, Value candidate,
35 function_ref<bool(Value, Value)> equivalenceFn) {
36 assert(isa<SubsetInsertionOpInterface>(op) &&
37 "expected SubsetInsertionOpInterface");
38 if (!candidate.getDefiningOp<SubsetExtractionOpInterface>())
39 return false;
40 return cast<SubsetOpInterface>(op).operatesOnEquivalentSubset(
41 candidate.getDefiningOp<SubsetOpInterface>(), equivalenceFn);
42}
43
44bool detail::defaultOperatesOnEquivalentSubset(
45 Operation *op, SubsetOpInterface candidate,
46 function_ref<bool(Value, Value)> equivalenceFn) {
47 auto subsetOp = cast<SubsetOpInterface>(op);
48 FailureOr<HyperrectangularSlice> slice =
49 subsetOp.getAccessedHyperrectangularSlice();
50 assert(succeeded(slice) &&
51 "operatesOnEquivalentSubset must be implemented if "
52 "getAccessedHyperrectangularSlice is not implemented");
53 FailureOr<HyperrectangularSlice> otherSlice =
54 candidate.getAccessedHyperrectangularSlice();
55 if (failed(result: otherSlice))
56 return false;
57 if (!equivalenceFn(subsetOp.getTensorContainer(),
58 candidate.getTensorContainer()))
59 return false;
60 FailureOr<bool> equivalent = ValueBoundsConstraintSet::areEquivalentSlices(
61 ctx: op->getContext(), slice1: *slice, slice2: *otherSlice);
62 return succeeded(result: equivalent) && *equivalent;
63}
64
65bool detail::defaultOperatesOnDisjointSubset(
66 Operation *op, SubsetOpInterface candidate,
67 function_ref<bool(Value, Value)> equivalenceFn) {
68 auto subsetOp = cast<SubsetOpInterface>(op);
69 FailureOr<HyperrectangularSlice> slice =
70 subsetOp.getAccessedHyperrectangularSlice();
71 assert(succeeded(slice) &&
72 "defaultOperatesOnDisjointSubset must be implemented if "
73 "getAccessedHyperrectangularSlice is not implemented");
74 FailureOr<HyperrectangularSlice> otherSlice =
75 candidate.getAccessedHyperrectangularSlice();
76 if (failed(result: otherSlice))
77 return false;
78 if (!equivalenceFn(subsetOp.getTensorContainer(),
79 candidate.getTensorContainer()))
80 return false;
81 FailureOr<bool> overlapping = ValueBoundsConstraintSet::areOverlappingSlices(
82 ctx: op->getContext(), slice1: *slice, slice2: *otherSlice);
83 return succeeded(result: overlapping) && !*overlapping;
84}
85
86Value detail::getTensorContainer(Operation *op) {
87 if (auto insertionOp = dyn_cast<::mlir::SubsetInsertionOpInterface>(op))
88 return insertionOp.getDestinationOperand().get();
89 return cast<::mlir::SubsetExtractionOpInterface>(op).getSourceOperand().get();
90}
91
92LogicalResult detail::verifySubsetOpInterface(SubsetOpInterface op) {
93 if (!(isa<SubsetExtractionOpInterface>(op.getOperation()) ^
94 isa<SubsetInsertionOpInterface>(op.getOperation())))
95 return op->emitOpError(
96 "SubsetOpInterface ops must implement either "
97 "SubsetExtractionOpInterface or SubsetInsertionOpInterface");
98 return success();
99}
100
101LogicalResult
102detail::verifySubsetExtractionOpInterface(SubsetExtractionOpInterface op) {
103 if (op->getNumResults() != 1)
104 return op->emitOpError(
105 "SubsetExtractionOpInterface ops must have one result");
106 return success();
107}
108

// Source: mlir/lib/Interfaces/SubsetOpInterface.cpp