//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Tensor/Transforms/RuntimeOpVerification.h"
10
11#include "mlir/Dialect/Arith/IR/Arith.h"
12#include "mlir/Dialect/Arith/Utils/Utils.h"
13#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
14#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
15#include "mlir/Dialect/Tensor/IR/Tensor.h"
16#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
17
18using namespace mlir;
19
20namespace mlir {
21namespace tensor {
22namespace {
23/// Generate a runtime check for lb <= value < ub.
24Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
25 Value lb, Value ub) {
26 Value inBounds1 = builder.createOrFold<arith::CmpIOp>(
27 location: loc, args: arith::CmpIPredicate::sge, args&: value, args&: lb);
28 Value inBounds2 = builder.createOrFold<arith::CmpIOp>(
29 location: loc, args: arith::CmpIPredicate::slt, args&: value, args&: ub);
30 Value inBounds =
31 builder.createOrFold<arith::AndIOp>(location: loc, args&: inBounds1, args&: inBounds2);
32 return inBounds;
33}
34
35struct CastOpInterface
36 : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
37 CastOp> {
38 void generateRuntimeVerification(Operation *op, OpBuilder &builder,
39 Location loc) const {
40 auto castOp = cast<CastOp>(Val: op);
41 auto srcType = cast<TensorType>(Val: castOp.getSource().getType());
42
43 // Nothing to check if the result is an unranked tensor.
44 auto resultType = dyn_cast<RankedTensorType>(Val: castOp.getType());
45 if (!resultType)
46 return;
47
48 if (isa<UnrankedTensorType>(Val: srcType)) {
49 // Check rank.
50 Value srcRank = builder.create<RankOp>(location: loc, args: castOp.getSource());
51 Value resultRank =
52 builder.create<arith::ConstantIndexOp>(location: loc, args: resultType.getRank());
53 Value isSameRank = builder.create<arith::CmpIOp>(
54 location: loc, args: arith::CmpIPredicate::eq, args&: srcRank, args&: resultRank);
55 builder.create<cf::AssertOp>(
56 location: loc, args&: isSameRank,
57 args: RuntimeVerifiableOpInterface::generateErrorMessage(op,
58 msg: "rank mismatch"));
59 }
60
61 // Check dimension sizes.
62 for (const auto &it : llvm::enumerate(First: resultType.getShape())) {
63 // Static dim size -> static/dynamic dim size does not need verification.
64 if (auto rankedSrcType = dyn_cast<RankedTensorType>(Val&: srcType))
65 if (!rankedSrcType.isDynamicDim(idx: it.index()))
66 continue;
67
68 // Static/dynamic dim size -> dynamic dim size does not need verification.
69 if (resultType.isDynamicDim(idx: it.index()))
70 continue;
71
72 Value srcDimSz =
73 builder.create<DimOp>(location: loc, args: castOp.getSource(), args: it.index());
74 Value resultDimSz =
75 builder.create<arith::ConstantIndexOp>(location: loc, args: it.value());
76 Value isSameSz = builder.create<arith::CmpIOp>(
77 location: loc, args: arith::CmpIPredicate::eq, args&: srcDimSz, args&: resultDimSz);
78 builder.create<cf::AssertOp>(
79 location: loc, args&: isSameSz,
80 args: RuntimeVerifiableOpInterface::generateErrorMessage(
81 op, msg: "size mismatch of dim " + std::to_string(val: it.index())));
82 }
83 }
84};
85
86struct DimOpInterface
87 : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
88 DimOp> {
89 void generateRuntimeVerification(Operation *op, OpBuilder &builder,
90 Location loc) const {
91 auto dimOp = cast<DimOp>(Val: op);
92 Value rank = builder.create<RankOp>(location: loc, args: dimOp.getSource());
93 Value zero = builder.create<arith::ConstantIndexOp>(location: loc, args: 0);
94 builder.create<cf::AssertOp>(
95 location: loc, args: generateInBoundsCheck(builder, loc, value: dimOp.getIndex(), lb: zero, ub: rank),
96 args: RuntimeVerifiableOpInterface::generateErrorMessage(
97 op, msg: "index is out of bounds"));
98 }
99};
100
101/// Verifies that the indices on extract/insert ops are in-bounds of the
102/// tensor's index space: 0 <= index#i < dim#i
103template <typename OpTy>
104struct ExtractInsertOpInterface
105 : public RuntimeVerifiableOpInterface::ExternalModel<
106 ExtractInsertOpInterface<OpTy>, OpTy> {
107 void generateRuntimeVerification(Operation *op, OpBuilder &builder,
108 Location loc) const {
109 auto extractInsertOp = cast<OpTy>(op);
110
111 Value tensor;
112 if constexpr (std::is_same_v<OpTy, ExtractOp>) {
113 tensor = extractInsertOp.getTensor();
114 } else if constexpr (std::is_same_v<OpTy, InsertOp>) {
115 tensor = extractInsertOp.getDest();
116 } else {
117 llvm_unreachable("invalid op");
118 }
119 auto tensorType = cast<RankedTensorType>(Val: tensor.getType());
120 auto rank = tensorType.getRank();
121 if (rank == 0) {
122 // Nothing to check for 0-d tensors.
123 return;
124 }
125
126 auto indices = extractInsertOp.getIndices();
127 auto zero = builder.create<arith::ConstantIndexOp>(location: loc, args: 0);
128 Value assertCond;
129 for (auto i : llvm::seq<int64_t>(Begin: 0, End: rank)) {
130 Value dimOp = builder.createOrFold<tensor::DimOp>(location: loc, args&: tensor, args&: i);
131 Value inBounds =
132 generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
133 assertCond =
134 i > 0 ? builder.createOrFold<arith::AndIOp>(location: loc, args&: assertCond, args&: inBounds)
135 : inBounds;
136 }
137 builder.create<cf::AssertOp>(
138 location: loc, args&: assertCond,
139 args: RuntimeVerifiableOpInterface::generateErrorMessage(
140 op, msg: "out-of-bounds access"));
141 }
142};
143
144struct ExtractSliceOpInterface
145 : public RuntimeVerifiableOpInterface::ExternalModel<
146 ExtractSliceOpInterface, ExtractSliceOp> {
147 void generateRuntimeVerification(Operation *op, OpBuilder &builder,
148 Location loc) const {
149 auto extractSliceOp = cast<ExtractSliceOp>(Val: op);
150 RankedTensorType sourceType = extractSliceOp.getSource().getType();
151
152 // For each dimension, assert that:
153 // 0 <= offset < dim_size
154 // 0 <= offset + (size - 1) * stride < dim_size
155 Value zero = builder.create<arith::ConstantIndexOp>(location: loc, args: 0);
156 Value one = builder.create<arith::ConstantIndexOp>(location: loc, args: 1);
157 for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
158 Value offset = getValueOrCreateConstantIndexOp(
159 b&: builder, loc, ofr: extractSliceOp.getMixedOffsets()[i]);
160 Value size = getValueOrCreateConstantIndexOp(
161 b&: builder, loc, ofr: extractSliceOp.getMixedSizes()[i]);
162 Value stride = getValueOrCreateConstantIndexOp(
163 b&: builder, loc, ofr: extractSliceOp.getMixedStrides()[i]);
164
165 // Verify that offset is in-bounds.
166 Value dimSize = builder.createOrFold<tensor::DimOp>(
167 location: loc, args: extractSliceOp.getSource(), args&: i);
168 Value offsetInBounds =
169 generateInBoundsCheck(builder, loc, value: offset, lb: zero, ub: dimSize);
170 builder.create<cf::AssertOp>(
171 location: loc, args&: offsetInBounds,
172 args: RuntimeVerifiableOpInterface::generateErrorMessage(
173 op, msg: "offset " + std::to_string(val: i) + " is out-of-bounds"));
174
175 // Verify that slice does not run out-of-bounds.
176 Value sizeMinusOne = builder.create<arith::SubIOp>(location: loc, args&: size, args&: one);
177 Value sizeMinusOneTimesStride =
178 builder.create<arith::MulIOp>(location: loc, args&: sizeMinusOne, args&: stride);
179 Value lastPos =
180 builder.create<arith::AddIOp>(location: loc, args&: offset, args&: sizeMinusOneTimesStride);
181 Value lastPosInBounds =
182 generateInBoundsCheck(builder, loc, value: lastPos, lb: zero, ub: dimSize);
183 builder.create<cf::AssertOp>(
184 location: loc, args&: lastPosInBounds,
185 args: RuntimeVerifiableOpInterface::generateErrorMessage(
186 op, msg: "extract_slice runs out-of-bounds along dimension " +
187 std::to_string(val: i)));
188 }
189 }
190};
191} // namespace
192} // namespace tensor
193} // namespace mlir
194
195void mlir::tensor::registerRuntimeVerifiableOpInterfaceExternalModels(
196 DialectRegistry &registry) {
197 registry.addExtension(extensionFn: +[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
198 CastOp::attachInterface<CastOpInterface>(context&: *ctx);
199 DimOp::attachInterface<DimOpInterface>(context&: *ctx);
200 ExtractOp::attachInterface<ExtractInsertOpInterface<ExtractOp>>(context&: *ctx);
201 ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(context&: *ctx);
202 InsertOp::attachInterface<ExtractInsertOpInterface<InsertOp>>(context&: *ctx);
203
204 // Load additional dialects of which ops may get created.
205 ctx->loadDialect<arith::ArithDialect, cf::ControlFlowDialect>();
206 });
207}
208

source code of mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp