//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h"
10
11#include "mlir/Dialect/Tensor/IR/Tensor.h"
12#include "mlir/Interfaces/ValueBoundsOpInterface.h"
13
14using namespace mlir;
15
16namespace mlir {
17namespace tensor {
18namespace {
19
20struct CastOpInterface
21 : public ValueBoundsOpInterface::ExternalModel<CastOpInterface, CastOp> {
22 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
23 ValueBoundsConstraintSet &cstr) const {
24 auto castOp = cast<CastOp>(op);
25 assert(value == castOp.getResult() && "invalid value");
26
27 if (llvm::isa<RankedTensorType>(castOp.getResult().getType()) &&
28 llvm::isa<RankedTensorType>(castOp.getSource().getType())) {
29 cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim);
30 }
31 }
32};
33
34struct DimOpInterface
35 : public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> {
36 void populateBoundsForIndexValue(Operation *op, Value value,
37 ValueBoundsConstraintSet &cstr) const {
38 auto dimOp = cast<DimOp>(op);
39 assert(value == dimOp.getResult() && "invalid value");
40
41 auto constIndex = dimOp.getConstantIndex();
42 if (!constIndex.has_value())
43 return;
44 cstr.bound(value) == cstr.getExpr(dimOp.getSource(), *constIndex);
45 }
46};
47
48struct EmptyOpInterface
49 : public ValueBoundsOpInterface::ExternalModel<EmptyOpInterface, EmptyOp> {
50 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
51 ValueBoundsConstraintSet &cstr) const {
52 auto emptyOp = cast<EmptyOp>(op);
53 assert(value == emptyOp.getResult() && "invalid value");
54
55 cstr.bound(value)[dim] == emptyOp.getMixedSizes()[dim];
56 }
57};
58
59struct ExtractSliceOpInterface
60 : public ValueBoundsOpInterface::ExternalModel<ExtractSliceOpInterface,
61 ExtractSliceOp> {
62 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
63 ValueBoundsConstraintSet &cstr) const {
64 auto extractSliceOp = cast<ExtractSliceOp>(op);
65 assert(value == extractSliceOp.getResult() && "invalid value");
66
67 llvm::SmallBitVector dropped = extractSliceOp.getDroppedDims();
68 int64_t ctr = -1;
69 for (int64_t i = 0, e = extractSliceOp.getMixedSizes().size(); i < e; ++i) {
70 // Skip over rank-reduced dimensions.
71 if (!dropped.test(Idx: i))
72 ++ctr;
73 if (ctr == dim) {
74 cstr.bound(value)[dim] == extractSliceOp.getMixedSizes()[i];
75 return;
76 }
77 }
78 llvm_unreachable("could not find non-rank-reduced dim");
79 }
80};
81
82struct PadOpInterface
83 : public ValueBoundsOpInterface::ExternalModel<PadOpInterface, PadOp> {
84 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
85 ValueBoundsConstraintSet &cstr) const {
86 auto padOp = cast<PadOp>(op);
87 assert(value == padOp.getResult() && "invalid value");
88
89 AffineExpr srcSize = cstr.getExpr(padOp.getSource(), dim);
90 AffineExpr lowPad = cstr.getExpr(padOp.getMixedLowPad()[dim]);
91 AffineExpr highPad = cstr.getExpr(padOp.getMixedHighPad()[dim]);
92 cstr.bound(value)[dim] == srcSize + lowPad + highPad;
93 }
94};
95
96struct RankOpInterface
97 : public ValueBoundsOpInterface::ExternalModel<RankOpInterface, RankOp> {
98 void populateBoundsForIndexValue(Operation *op, Value value,
99 ValueBoundsConstraintSet &cstr) const {
100 auto rankOp = cast<RankOp>(op);
101 assert(value == rankOp.getResult() && "invalid value");
102
103 auto tensorType =
104 llvm::dyn_cast<RankedTensorType>(rankOp.getTensor().getType());
105 if (!tensorType)
106 return;
107 cstr.bound(value) == tensorType.getRank();
108 }
109};
110
111} // namespace
112} // namespace tensor
113} // namespace mlir
114
115void mlir::tensor::registerValueBoundsOpInterfaceExternalModels(
116 DialectRegistry &registry) {
117 registry.addExtension(extensionFn: +[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
118 tensor::CastOp::attachInterface<tensor::CastOpInterface>(*ctx);
119 tensor::DimOp::attachInterface<tensor::DimOpInterface>(*ctx);
120 tensor::EmptyOp::attachInterface<tensor::EmptyOpInterface>(*ctx);
121 tensor::ExtractSliceOp::attachInterface<tensor::ExtractSliceOpInterface>(
122 *ctx);
123 tensor::PadOp::attachInterface<tensor::PadOpInterface>(*ctx);
124 tensor::RankOp::attachInterface<tensor::RankOpInterface>(*ctx);
125 // Note: ValueBoundsOpInterface implementation is not required for ops that
126 // implement `DestinationStyleOpInterface` (for querying shaped OpResults).
127 });
128}
129

source code of mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp