1//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
10
11#include "mlir/Dialect/MemRef/IR/MemRef.h"
12#include "mlir/Interfaces/ValueBoundsOpInterface.h"
13
14using namespace mlir;
15
16namespace mlir {
17namespace memref {
18namespace {
19
20template <typename OpTy>
21struct AllocOpInterface
22 : public ValueBoundsOpInterface::ExternalModel<AllocOpInterface<OpTy>,
23 OpTy> {
24 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
25 ValueBoundsConstraintSet &cstr) const {
26 auto allocOp = cast<OpTy>(op);
27 assert(value == allocOp.getResult() && "invalid value");
28
29 cstr.bound(value)[dim] == allocOp.getMixedSizes()[dim];
30 }
31};
32
33struct CastOpInterface
34 : public ValueBoundsOpInterface::ExternalModel<CastOpInterface, CastOp> {
35 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
36 ValueBoundsConstraintSet &cstr) const {
37 auto castOp = cast<CastOp>(op);
38 assert(value == castOp.getResult() && "invalid value");
39
40 if (llvm::isa<MemRefType>(castOp.getResult().getType()) &&
41 llvm::isa<MemRefType>(castOp.getSource().getType())) {
42 cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim);
43 }
44 }
45};
46
47struct DimOpInterface
48 : public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> {
49 void populateBoundsForIndexValue(Operation *op, Value value,
50 ValueBoundsConstraintSet &cstr) const {
51 auto dimOp = cast<DimOp>(op);
52 assert(value == dimOp.getResult() && "invalid value");
53
54 cstr.bound(value) >= 0;
55 auto constIndex = dimOp.getConstantIndex();
56 if (!constIndex.has_value())
57 return;
58 cstr.bound(value) == cstr.getExpr(dimOp.getSource(), *constIndex);
59 }
60};
61
62struct GetGlobalOpInterface
63 : public ValueBoundsOpInterface::ExternalModel<GetGlobalOpInterface,
64 GetGlobalOp> {
65 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
66 ValueBoundsConstraintSet &cstr) const {
67 auto getGlobalOp = cast<GetGlobalOp>(op);
68 assert(value == getGlobalOp.getResult() && "invalid value");
69
70 auto type = getGlobalOp.getType();
71 assert(!type.isDynamicDim(dim) && "expected static dim");
72 cstr.bound(value)[dim] == type.getDimSize(dim);
73 }
74};
75
76struct RankOpInterface
77 : public ValueBoundsOpInterface::ExternalModel<RankOpInterface, RankOp> {
78 void populateBoundsForIndexValue(Operation *op, Value value,
79 ValueBoundsConstraintSet &cstr) const {
80 auto rankOp = cast<RankOp>(op);
81 assert(value == rankOp.getResult() && "invalid value");
82
83 auto memrefType = llvm::dyn_cast<MemRefType>(rankOp.getMemref().getType());
84 if (!memrefType)
85 return;
86 cstr.bound(value) == memrefType.getRank();
87 }
88};
89
90struct SubViewOpInterface
91 : public ValueBoundsOpInterface::ExternalModel<SubViewOpInterface,
92 SubViewOp> {
93 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
94 ValueBoundsConstraintSet &cstr) const {
95 auto subViewOp = cast<SubViewOp>(op);
96 assert(value == subViewOp.getResult() && "invalid value");
97
98 llvm::SmallBitVector dropped = subViewOp.getDroppedDims();
99 int64_t ctr = -1;
100 for (int64_t i = 0, e = subViewOp.getMixedSizes().size(); i < e; ++i) {
101 // Skip over rank-reduced dimensions.
102 if (!dropped.test(Idx: i))
103 ++ctr;
104 if (ctr == dim) {
105 cstr.bound(value)[dim] == subViewOp.getMixedSizes()[i];
106 return;
107 }
108 }
109 llvm_unreachable("could not find non-rank-reduced dim");
110 }
111};
112
113} // namespace
114} // namespace memref
115} // namespace mlir
116
117void mlir::memref::registerValueBoundsOpInterfaceExternalModels(
118 DialectRegistry &registry) {
119 registry.addExtension(extensionFn: +[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
120 memref::AllocOp::attachInterface<memref::AllocOpInterface<memref::AllocOp>>(
121 *ctx);
122 memref::AllocaOp::attachInterface<
123 memref::AllocOpInterface<memref::AllocaOp>>(*ctx);
124 memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx);
125 memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx);
126 memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx);
127 memref::RankOp::attachInterface<memref::RankOpInterface>(*ctx);
128 memref::SubViewOp::attachInterface<memref::SubViewOpInterface>(*ctx);
129 });
130}
131

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp