1//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h"
10
11#include "mlir/Dialect/Tensor/IR/Tensor.h"
12#include "mlir/Interfaces/ValueBoundsOpInterface.h"
13
14using namespace mlir;
15
16namespace mlir {
17namespace tensor {
18namespace {
19
20struct CastOpInterface
21 : public ValueBoundsOpInterface::ExternalModel<CastOpInterface, CastOp> {
22 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
23 ValueBoundsConstraintSet &cstr) const {
24 auto castOp = cast<CastOp>(op);
25 assert(value == castOp.getResult() && "invalid value");
26
27 if (llvm::isa<RankedTensorType>(castOp.getResult().getType()) &&
28 llvm::isa<RankedTensorType>(castOp.getSource().getType())) {
29 cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim);
30 }
31 }
32};
33
34struct DimOpInterface
35 : public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> {
36 void populateBoundsForIndexValue(Operation *op, Value value,
37 ValueBoundsConstraintSet &cstr) const {
38 auto dimOp = cast<DimOp>(op);
39 assert(value == dimOp.getResult() && "invalid value");
40
41 cstr.bound(value) >= 0;
42 auto constIndex = dimOp.getConstantIndex();
43 if (!constIndex.has_value())
44 return;
45 cstr.bound(value) == cstr.getExpr(dimOp.getSource(), *constIndex);
46 }
47};
48
49struct EmptyOpInterface
50 : public ValueBoundsOpInterface::ExternalModel<EmptyOpInterface, EmptyOp> {
51 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
52 ValueBoundsConstraintSet &cstr) const {
53 auto emptyOp = cast<EmptyOp>(op);
54 assert(value == emptyOp.getResult() && "invalid value");
55
56 cstr.bound(value)[dim] == emptyOp.getMixedSizes()[dim];
57 }
58};
59
60struct ExtractSliceOpInterface
61 : public ValueBoundsOpInterface::ExternalModel<ExtractSliceOpInterface,
62 ExtractSliceOp> {
63 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
64 ValueBoundsConstraintSet &cstr) const {
65 auto extractSliceOp = cast<ExtractSliceOp>(op);
66 assert(value == extractSliceOp.getResult() && "invalid value");
67
68 llvm::SmallBitVector dropped = extractSliceOp.getDroppedDims();
69 int64_t ctr = -1;
70 for (int64_t i = 0, e = extractSliceOp.getMixedSizes().size(); i < e; ++i) {
71 // Skip over rank-reduced dimensions.
72 if (!dropped.test(Idx: i))
73 ++ctr;
74 if (ctr == dim) {
75 cstr.bound(value)[dim] == extractSliceOp.getMixedSizes()[i];
76 return;
77 }
78 }
79 llvm_unreachable("could not find non-rank-reduced dim");
80 }
81};
82
83struct PadOpInterface
84 : public ValueBoundsOpInterface::ExternalModel<PadOpInterface, PadOp> {
85 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
86 ValueBoundsConstraintSet &cstr) const {
87 auto padOp = cast<PadOp>(op);
88 assert(value == padOp.getResult() && "invalid value");
89
90 AffineExpr srcSize = cstr.getExpr(padOp.getSource(), dim);
91 AffineExpr lowPad = cstr.getExpr(padOp.getMixedLowPad()[dim]);
92 AffineExpr highPad = cstr.getExpr(padOp.getMixedHighPad()[dim]);
93 cstr.bound(value)[dim] == srcSize + lowPad + highPad;
94 }
95};
96
97struct RankOpInterface
98 : public ValueBoundsOpInterface::ExternalModel<RankOpInterface, RankOp> {
99 void populateBoundsForIndexValue(Operation *op, Value value,
100 ValueBoundsConstraintSet &cstr) const {
101 auto rankOp = cast<RankOp>(op);
102 assert(value == rankOp.getResult() && "invalid value");
103
104 auto tensorType =
105 llvm::dyn_cast<RankedTensorType>(rankOp.getTensor().getType());
106 if (!tensorType)
107 return;
108 cstr.bound(value) == tensorType.getRank();
109 }
110};
111
112} // namespace
113} // namespace tensor
114} // namespace mlir
115
116void mlir::tensor::registerValueBoundsOpInterfaceExternalModels(
117 DialectRegistry &registry) {
118 registry.addExtension(extensionFn: +[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
119 tensor::CastOp::attachInterface<tensor::CastOpInterface>(*ctx);
120 tensor::DimOp::attachInterface<tensor::DimOpInterface>(*ctx);
121 tensor::EmptyOp::attachInterface<tensor::EmptyOpInterface>(*ctx);
122 tensor::ExtractSliceOp::attachInterface<tensor::ExtractSliceOpInterface>(
123 *ctx);
124 tensor::PadOp::attachInterface<tensor::PadOpInterface>(*ctx);
125 tensor::RankOp::attachInterface<tensor::RankOpInterface>(*ctx);
126 // Note: ValueBoundsOpInterface implementation is not required for ops that
127 // implement `DestinationStyleOpInterface` (for querying shaped OpResults).
128 });
129}
130

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp