1//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Tensor/Transforms/RuntimeOpVerification.h"
10
11#include "mlir/Dialect/Arith/IR/Arith.h"
12#include "mlir/Dialect/Arith/Utils/Utils.h"
13#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
14#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
15#include "mlir/Dialect/Tensor/IR/Tensor.h"
16#include "mlir/Dialect/Utils/IndexingUtils.h"
17#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
18
19using namespace mlir;
20
21namespace mlir {
22namespace tensor {
23namespace {
24/// Generate a runtime check for lb <= value < ub.
25Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
26 Value lb, Value ub) {
27 Value inBounds1 = builder.createOrFold<arith::CmpIOp>(
28 loc, arith::CmpIPredicate::sge, value, lb);
29 Value inBounds2 = builder.createOrFold<arith::CmpIOp>(
30 loc, arith::CmpIPredicate::slt, value, ub);
31 Value inBounds =
32 builder.createOrFold<arith::AndIOp>(loc, inBounds1, inBounds2);
33 return inBounds;
34}
35
36struct CastOpInterface
37 : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
38 CastOp> {
39 void generateRuntimeVerification(Operation *op, OpBuilder &builder,
40 Location loc) const {
41 auto castOp = cast<CastOp>(op);
42 auto srcType = cast<TensorType>(castOp.getSource().getType());
43
44 // Nothing to check if the result is an unranked tensor.
45 auto resultType = dyn_cast<RankedTensorType>(castOp.getType());
46 if (!resultType)
47 return;
48
49 if (isa<UnrankedTensorType>(srcType)) {
50 // Check rank.
51 Value srcRank = builder.create<RankOp>(loc, castOp.getSource());
52 Value resultRank =
53 builder.create<arith::ConstantIndexOp>(loc, resultType.getRank());
54 Value isSameRank = builder.create<arith::CmpIOp>(
55 loc, arith::CmpIPredicate::eq, srcRank, resultRank);
56 builder.create<cf::AssertOp>(
57 loc, isSameRank,
58 RuntimeVerifiableOpInterface::generateErrorMessage(op,
59 "rank mismatch"));
60 }
61
62 // Check dimension sizes.
63 for (const auto &it : llvm::enumerate(resultType.getShape())) {
64 // Static dim size -> static/dynamic dim size does not need verification.
65 if (auto rankedSrcType = dyn_cast<RankedTensorType>(srcType))
66 if (!rankedSrcType.isDynamicDim(it.index()))
67 continue;
68
69 // Static/dynamic dim size -> dynamic dim size does not need verification.
70 if (resultType.isDynamicDim(it.index()))
71 continue;
72
73 Value srcDimSz =
74 builder.create<DimOp>(loc, castOp.getSource(), it.index());
75 Value resultDimSz =
76 builder.create<arith::ConstantIndexOp>(loc, it.value());
77 Value isSameSz = builder.create<arith::CmpIOp>(
78 loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
79 builder.create<cf::AssertOp>(
80 loc, isSameSz,
81 RuntimeVerifiableOpInterface::generateErrorMessage(
82 op, "size mismatch of dim " + std::to_string(it.index())));
83 }
84 }
85};
86
87struct DimOpInterface
88 : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
89 DimOp> {
90 void generateRuntimeVerification(Operation *op, OpBuilder &builder,
91 Location loc) const {
92 auto dimOp = cast<DimOp>(op);
93 Value rank = builder.create<RankOp>(loc, dimOp.getSource());
94 Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
95 builder.create<cf::AssertOp>(
96 loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
97 RuntimeVerifiableOpInterface::generateErrorMessage(
98 op, "index is out of bounds"));
99 }
100};
101
102/// Verifies that the indices on extract/insert ops are in-bounds of the
103/// tensor's index space: 0 <= index#i < dim#i
104template <typename OpTy>
105struct ExtractInsertOpInterface
106 : public RuntimeVerifiableOpInterface::ExternalModel<
107 ExtractInsertOpInterface<OpTy>, OpTy> {
108 void generateRuntimeVerification(Operation *op, OpBuilder &builder,
109 Location loc) const {
110 auto extractInsertOp = cast<OpTy>(op);
111
112 Value tensor;
113 if constexpr (std::is_same_v<OpTy, ExtractOp>) {
114 tensor = extractInsertOp.getTensor();
115 } else if constexpr (std::is_same_v<OpTy, InsertOp>) {
116 tensor = extractInsertOp.getDest();
117 } else {
118 llvm_unreachable("invalid op");
119 }
120 auto tensorType = cast<RankedTensorType>(tensor.getType());
121 auto rank = tensorType.getRank();
122 if (rank == 0) {
123 // Nothing to check for 0-d tensors.
124 return;
125 }
126
127 auto indices = extractInsertOp.getIndices();
128 auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
129 Value assertCond;
130 for (auto i : llvm::seq<int64_t>(0, rank)) {
131 Value dimOp = builder.createOrFold<tensor::DimOp>(loc, tensor, i);
132 Value inBounds =
133 generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
134 assertCond =
135 i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
136 : inBounds;
137 }
138 builder.create<cf::AssertOp>(
139 loc, assertCond,
140 RuntimeVerifiableOpInterface::generateErrorMessage(
141 op, "out-of-bounds access"));
142 }
143};
144
145struct ExtractSliceOpInterface
146 : public RuntimeVerifiableOpInterface::ExternalModel<
147 ExtractSliceOpInterface, ExtractSliceOp> {
148 void generateRuntimeVerification(Operation *op, OpBuilder &builder,
149 Location loc) const {
150 auto extractSliceOp = cast<ExtractSliceOp>(op);
151 RankedTensorType sourceType = extractSliceOp.getSource().getType();
152
153 // For each dimension, assert that:
154 // 0 <= offset < dim_size
155 // 0 <= offset + (size - 1) * stride < dim_size
156 Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
157 Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
158 for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
159 Value offset = getValueOrCreateConstantIndexOp(
160 builder, loc, extractSliceOp.getMixedOffsets()[i]);
161 Value size = getValueOrCreateConstantIndexOp(
162 builder, loc, extractSliceOp.getMixedSizes()[i]);
163 Value stride = getValueOrCreateConstantIndexOp(
164 builder, loc, extractSliceOp.getMixedStrides()[i]);
165
166 // Verify that offset is in-bounds.
167 Value dimSize = builder.createOrFold<tensor::DimOp>(
168 loc, extractSliceOp.getSource(), i);
169 Value offsetInBounds =
170 generateInBoundsCheck(builder, loc, value: offset, lb: zero, ub: dimSize);
171 builder.create<cf::AssertOp>(
172 loc, offsetInBounds,
173 RuntimeVerifiableOpInterface::generateErrorMessage(
174 op, "offset " + std::to_string(i) + " is out-of-bounds"));
175
176 // Verify that slice does not run out-of-bounds.
177 Value sizeMinusOne = builder.create<arith::SubIOp>(loc, size, one);
178 Value sizeMinusOneTimesStride =
179 builder.create<arith::MulIOp>(loc, sizeMinusOne, stride);
180 Value lastPos =
181 builder.create<arith::AddIOp>(loc, offset, sizeMinusOneTimesStride);
182 Value lastPosInBounds =
183 generateInBoundsCheck(builder, loc, value: lastPos, lb: zero, ub: dimSize);
184 builder.create<cf::AssertOp>(
185 loc, lastPosInBounds,
186 RuntimeVerifiableOpInterface::generateErrorMessage(
187 op, "extract_slice runs out-of-bounds along dimension " +
188 std::to_string(i)));
189 }
190 }
191};
192} // namespace
193} // namespace tensor
194} // namespace mlir
195
196void mlir::tensor::registerRuntimeVerifiableOpInterfaceExternalModels(
197 DialectRegistry &registry) {
198 registry.addExtension(extensionFn: +[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
199 CastOp::attachInterface<CastOpInterface>(*ctx);
200 DimOp::attachInterface<DimOpInterface>(*ctx);
201 ExtractOp::attachInterface<ExtractInsertOpInterface<ExtractOp>>(*ctx);
202 ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
203 InsertOp::attachInterface<ExtractInsertOpInterface<InsertOp>>(*ctx);
204
205 // Load additional dialects of which ops may get created.
206 ctx->loadDialect<arith::ArithDialect, cf::ControlFlowDialect>();
207 });
208}
209

// Source code of mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp