1//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
10
11#include "mlir/Dialect/Affine/IR/AffineOps.h"
12#include "mlir/Dialect/Arith/IR/Arith.h"
13#include "mlir/Dialect/Arith/Utils/Utils.h"
14#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
15#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
16#include "mlir/Dialect/MemRef/IR/MemRef.h"
17#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
18#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
19
20using namespace mlir;
21
22namespace mlir {
23namespace memref {
24namespace {
25/// Generate a runtime check for lb <= value < ub.
26Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
27 Value lb, Value ub) {
28 Value inBounds1 = builder.createOrFold<arith::CmpIOp>(
29 location: loc, args: arith::CmpIPredicate::sge, args&: value, args&: lb);
30 Value inBounds2 = builder.createOrFold<arith::CmpIOp>(
31 location: loc, args: arith::CmpIPredicate::slt, args&: value, args&: ub);
32 Value inBounds =
33 builder.createOrFold<arith::AndIOp>(location: loc, args&: inBounds1, args&: inBounds2);
34 return inBounds;
35}
36
37struct AssumeAlignmentOpInterface
38 : public RuntimeVerifiableOpInterface::ExternalModel<
39 AssumeAlignmentOpInterface, AssumeAlignmentOp> {
40 void generateRuntimeVerification(Operation *op, OpBuilder &builder,
41 Location loc) const {
42 auto assumeOp = cast<AssumeAlignmentOp>(Val: op);
43 Value ptr = builder.create<ExtractAlignedPointerAsIndexOp>(
44 location: loc, args: assumeOp.getMemref());
45 Value rest = builder.create<arith::RemUIOp>(
46 location: loc, args&: ptr,
47 args: builder.create<arith::ConstantIndexOp>(location: loc, args: assumeOp.getAlignment()));
48 Value isAligned = builder.create<arith::CmpIOp>(
49 location: loc, args: arith::CmpIPredicate::eq, args&: rest,
50 args: builder.create<arith::ConstantIndexOp>(location: loc, args: 0));
51 builder.create<cf::AssertOp>(
52 location: loc, args&: isAligned,
53 args: RuntimeVerifiableOpInterface::generateErrorMessage(
54 op, msg: "memref is not aligned to " +
55 std::to_string(val: assumeOp.getAlignment())));
56 }
57};
58
59struct CastOpInterface
60 : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
61 CastOp> {
62 void generateRuntimeVerification(Operation *op, OpBuilder &builder,
63 Location loc) const {
64 auto castOp = cast<CastOp>(Val: op);
65 auto srcType = cast<BaseMemRefType>(Val: castOp.getSource().getType());
66
67 // Nothing to check if the result is an unranked memref.
68 auto resultType = dyn_cast<MemRefType>(Val: castOp.getType());
69 if (!resultType)
70 return;
71
72 if (isa<UnrankedMemRefType>(Val: srcType)) {
73 // Check rank.
74 Value srcRank = builder.create<RankOp>(location: loc, args: castOp.getSource());
75 Value resultRank =
76 builder.create<arith::ConstantIndexOp>(location: loc, args: resultType.getRank());
77 Value isSameRank = builder.create<arith::CmpIOp>(
78 location: loc, args: arith::CmpIPredicate::eq, args&: srcRank, args&: resultRank);
79 builder.create<cf::AssertOp>(
80 location: loc, args&: isSameRank,
81 args: RuntimeVerifiableOpInterface::generateErrorMessage(op,
82 msg: "rank mismatch"));
83 }
84
85 // Get source offset and strides. We do not have an op to get offsets and
86 // strides from unranked memrefs, so cast the source to a type with fully
87 // dynamic layout, from which we can then extract the offset and strides.
88 // (Rank was already verified.)
89 int64_t dynamicOffset = ShapedType::kDynamic;
90 SmallVector<int64_t> dynamicShape(resultType.getRank(),
91 ShapedType::kDynamic);
92 auto stridedLayout = StridedLayoutAttr::get(context: builder.getContext(),
93 offset: dynamicOffset, strides: dynamicShape);
94 auto dynStridesType =
95 MemRefType::get(shape: dynamicShape, elementType: resultType.getElementType(),
96 layout: stridedLayout, memorySpace: resultType.getMemorySpace());
97 Value helperCast =
98 builder.create<CastOp>(location: loc, args&: dynStridesType, args: castOp.getSource());
99 auto metadataOp = builder.create<ExtractStridedMetadataOp>(location: loc, args&: helperCast);
100
101 // Check dimension sizes.
102 for (const auto &it : llvm::enumerate(First: resultType.getShape())) {
103 // Static dim size -> static/dynamic dim size does not need verification.
104 if (auto rankedSrcType = dyn_cast<MemRefType>(Val&: srcType))
105 if (!rankedSrcType.isDynamicDim(idx: it.index()))
106 continue;
107
108 // Static/dynamic dim size -> dynamic dim size does not need verification.
109 if (resultType.isDynamicDim(idx: it.index()))
110 continue;
111
112 Value srcDimSz =
113 builder.create<DimOp>(location: loc, args: castOp.getSource(), args: it.index());
114 Value resultDimSz =
115 builder.create<arith::ConstantIndexOp>(location: loc, args: it.value());
116 Value isSameSz = builder.create<arith::CmpIOp>(
117 location: loc, args: arith::CmpIPredicate::eq, args&: srcDimSz, args&: resultDimSz);
118 builder.create<cf::AssertOp>(
119 location: loc, args&: isSameSz,
120 args: RuntimeVerifiableOpInterface::generateErrorMessage(
121 op, msg: "size mismatch of dim " + std::to_string(val: it.index())));
122 }
123
124 // Get result offset and strides.
125 int64_t resultOffset;
126 SmallVector<int64_t> resultStrides;
127 if (failed(Result: resultType.getStridesAndOffset(strides&: resultStrides, offset&: resultOffset)))
128 return;
129
130 // Check offset.
131 if (resultOffset != ShapedType::kDynamic) {
132 // Static/dynamic offset -> dynamic offset does not need verification.
133 Value srcOffset = metadataOp.getResult(i: 1);
134 Value resultOffsetVal =
135 builder.create<arith::ConstantIndexOp>(location: loc, args&: resultOffset);
136 Value isSameOffset = builder.create<arith::CmpIOp>(
137 location: loc, args: arith::CmpIPredicate::eq, args&: srcOffset, args&: resultOffsetVal);
138 builder.create<cf::AssertOp>(
139 location: loc, args&: isSameOffset,
140 args: RuntimeVerifiableOpInterface::generateErrorMessage(
141 op, msg: "offset mismatch"));
142 }
143
144 // Check strides.
145 for (const auto &it : llvm::enumerate(First&: resultStrides)) {
146 // Static/dynamic stride -> dynamic stride does not need verification.
147 if (it.value() == ShapedType::kDynamic)
148 continue;
149
150 Value srcStride =
151 metadataOp.getResult(i: 2 + resultType.getRank() + it.index());
152 Value resultStrideVal =
153 builder.create<arith::ConstantIndexOp>(location: loc, args&: it.value());
154 Value isSameStride = builder.create<arith::CmpIOp>(
155 location: loc, args: arith::CmpIPredicate::eq, args&: srcStride, args&: resultStrideVal);
156 builder.create<cf::AssertOp>(
157 location: loc, args&: isSameStride,
158 args: RuntimeVerifiableOpInterface::generateErrorMessage(
159 op, msg: "stride mismatch of dim " + std::to_string(val: it.index())));
160 }
161 }
162};
163
164struct CopyOpInterface
165 : public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
166 CopyOp> {
167 void generateRuntimeVerification(Operation *op, OpBuilder &builder,
168 Location loc) const {
169 auto copyOp = cast<CopyOp>(Val: op);
170 BaseMemRefType sourceType = copyOp.getSource().getType();
171 BaseMemRefType targetType = copyOp.getTarget().getType();
172 auto rankedSourceType = dyn_cast<MemRefType>(Val&: sourceType);
173 auto rankedTargetType = dyn_cast<MemRefType>(Val&: targetType);
174
175 // TODO: Verification for unranked memrefs is not supported yet.
176 if (!rankedSourceType || !rankedTargetType)
177 return;
178
179 assert(sourceType.getRank() == targetType.getRank() && "rank mismatch");
180 for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
181 // Fully static dimensions in both source and target operand are already
182 // verified by the op verifier.
183 if (!rankedSourceType.isDynamicDim(idx: i) &&
184 !rankedTargetType.isDynamicDim(idx: i))
185 continue;
186 auto getDimSize = [&](Value memRef, MemRefType type,
187 int64_t dim) -> Value {
188 return type.isDynamicDim(idx: dim)
189 ? builder.create<DimOp>(location: loc, args&: memRef, args&: dim).getResult()
190 : builder
191 .create<arith::ConstantIndexOp>(location: loc,
192 args: type.getDimSize(idx: dim))
193 .getResult();
194 };
195 Value sourceDim = getDimSize(copyOp.getSource(), rankedSourceType, i);
196 Value targetDim = getDimSize(copyOp.getTarget(), rankedTargetType, i);
197 Value sameDimSize = builder.create<arith::CmpIOp>(
198 location: loc, args: arith::CmpIPredicate::eq, args&: sourceDim, args&: targetDim);
199 builder.create<cf::AssertOp>(
200 location: loc, args&: sameDimSize,
201 args: RuntimeVerifiableOpInterface::generateErrorMessage(
202 op, msg: "size of " + std::to_string(val: i) +
203 "-th source/target dim does not match"));
204 }
205 }
206};
207
208struct DimOpInterface
209 : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
210 DimOp> {
211 void generateRuntimeVerification(Operation *op, OpBuilder &builder,
212 Location loc) const {
213 auto dimOp = cast<DimOp>(Val: op);
214 Value rank = builder.create<RankOp>(location: loc, args: dimOp.getSource());
215 Value zero = builder.create<arith::ConstantIndexOp>(location: loc, args: 0);
216 builder.create<cf::AssertOp>(
217 location: loc, args: generateInBoundsCheck(builder, loc, value: dimOp.getIndex(), lb: zero, ub: rank),
218 args: RuntimeVerifiableOpInterface::generateErrorMessage(
219 op, msg: "index is out of bounds"));
220 }
221};
222
223/// Verifies that the indices on load/store ops are in-bounds of the memref's
224/// index space: 0 <= index#i < dim#i
225template <typename LoadStoreOp>
226struct LoadStoreOpInterface
227 : public RuntimeVerifiableOpInterface::ExternalModel<
228 LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
229 void generateRuntimeVerification(Operation *op, OpBuilder &builder,
230 Location loc) const {
231 auto loadStoreOp = cast<LoadStoreOp>(op);
232
233 auto memref = loadStoreOp.getMemref();
234 auto rank = memref.getType().getRank();
235 if (rank == 0) {
236 return;
237 }
238 auto indices = loadStoreOp.getIndices();
239
240 auto zero = builder.create<arith::ConstantIndexOp>(location: loc, args: 0);
241 Value assertCond;
242 for (auto i : llvm::seq<int64_t>(0, rank)) {
243 Value dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
244 Value inBounds =
245 generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
246 assertCond =
247 i > 0 ? builder.createOrFold<arith::AndIOp>(location: loc, args&: assertCond, args&: inBounds)
248 : inBounds;
249 }
250 builder.create<cf::AssertOp>(
251 location: loc, args&: assertCond,
252 args: RuntimeVerifiableOpInterface::generateErrorMessage(
253 op, msg: "out-of-bounds access"));
254 }
255};
256
257struct SubViewOpInterface
258 : public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
259 SubViewOp> {
260 void generateRuntimeVerification(Operation *op, OpBuilder &builder,
261 Location loc) const {
262 auto subView = cast<SubViewOp>(Val: op);
263 MemRefType sourceType = subView.getSource().getType();
264
265 // For each dimension, assert that:
266 // 0 <= offset < dim_size
267 // 0 <= offset + (size - 1) * stride < dim_size
268 Value zero = builder.create<arith::ConstantIndexOp>(location: loc, args: 0);
269 Value one = builder.create<arith::ConstantIndexOp>(location: loc, args: 1);
270 auto metadataOp =
271 builder.create<ExtractStridedMetadataOp>(location: loc, args: subView.getSource());
272 for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
273 Value offset = getValueOrCreateConstantIndexOp(
274 b&: builder, loc, ofr: subView.getMixedOffsets()[i]);
275 Value size = getValueOrCreateConstantIndexOp(b&: builder, loc,
276 ofr: subView.getMixedSizes()[i]);
277 Value stride = getValueOrCreateConstantIndexOp(
278 b&: builder, loc, ofr: subView.getMixedStrides()[i]);
279
280 // Verify that offset is in-bounds.
281 Value dimSize = metadataOp.getSizes()[i];
282 Value offsetInBounds =
283 generateInBoundsCheck(builder, loc, value: offset, lb: zero, ub: dimSize);
284 builder.create<cf::AssertOp>(
285 location: loc, args&: offsetInBounds,
286 args: RuntimeVerifiableOpInterface::generateErrorMessage(
287 op, msg: "offset " + std::to_string(val: i) + " is out-of-bounds"));
288
289 // Verify that slice does not run out-of-bounds.
290 Value sizeMinusOne = builder.create<arith::SubIOp>(location: loc, args&: size, args&: one);
291 Value sizeMinusOneTimesStride =
292 builder.create<arith::MulIOp>(location: loc, args&: sizeMinusOne, args&: stride);
293 Value lastPos =
294 builder.create<arith::AddIOp>(location: loc, args&: offset, args&: sizeMinusOneTimesStride);
295 Value lastPosInBounds =
296 generateInBoundsCheck(builder, loc, value: lastPos, lb: zero, ub: dimSize);
297 builder.create<cf::AssertOp>(
298 location: loc, args&: lastPosInBounds,
299 args: RuntimeVerifiableOpInterface::generateErrorMessage(
300 op, msg: "subview runs out-of-bounds along dimension " +
301 std::to_string(val: i)));
302 }
303 }
304};
305
306struct ExpandShapeOpInterface
307 : public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
308 ExpandShapeOp> {
309 void generateRuntimeVerification(Operation *op, OpBuilder &builder,
310 Location loc) const {
311 auto expandShapeOp = cast<ExpandShapeOp>(Val: op);
312
313 // Verify that the expanded dim sizes are a product of the collapsed dim
314 // size.
315 for (const auto &it :
316 llvm::enumerate(First: expandShapeOp.getReassociationIndices())) {
317 Value srcDimSz =
318 builder.create<DimOp>(location: loc, args: expandShapeOp.getSrc(), args: it.index());
319 int64_t groupSz = 1;
320 bool foundDynamicDim = false;
321 for (int64_t resultDim : it.value()) {
322 if (expandShapeOp.getResultType().isDynamicDim(idx: resultDim)) {
323 // Keep this assert here in case the op is extended in the future.
324 assert(!foundDynamicDim &&
325 "more than one dynamic dim found in reassoc group");
326 (void)foundDynamicDim;
327 foundDynamicDim = true;
328 continue;
329 }
330 groupSz *= expandShapeOp.getResultType().getDimSize(idx: resultDim);
331 }
332 Value staticResultDimSz =
333 builder.create<arith::ConstantIndexOp>(location: loc, args&: groupSz);
334 // staticResultDimSz must divide srcDimSz evenly.
335 Value mod =
336 builder.create<arith::RemSIOp>(location: loc, args&: srcDimSz, args&: staticResultDimSz);
337 Value isModZero = builder.create<arith::CmpIOp>(
338 location: loc, args: arith::CmpIPredicate::eq, args&: mod,
339 args: builder.create<arith::ConstantIndexOp>(location: loc, args: 0));
340 builder.create<cf::AssertOp>(
341 location: loc, args&: isModZero,
342 args: RuntimeVerifiableOpInterface::generateErrorMessage(
343 op, msg: "static result dims in reassoc group do not "
344 "divide src dim evenly"));
345 }
346 }
347};
348} // namespace
349} // namespace memref
350} // namespace mlir
351
352void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
353 DialectRegistry &registry) {
354 registry.addExtension(extensionFn: +[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
355 AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(context&: *ctx);
356 AtomicRMWOp::attachInterface<LoadStoreOpInterface<AtomicRMWOp>>(context&: *ctx);
357 CastOp::attachInterface<CastOpInterface>(context&: *ctx);
358 CopyOp::attachInterface<CopyOpInterface>(context&: *ctx);
359 DimOp::attachInterface<DimOpInterface>(context&: *ctx);
360 ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(context&: *ctx);
361 GenericAtomicRMWOp::attachInterface<
362 LoadStoreOpInterface<GenericAtomicRMWOp>>(context&: *ctx);
363 LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(context&: *ctx);
364 StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(context&: *ctx);
365 SubViewOp::attachInterface<SubViewOpInterface>(context&: *ctx);
366 // Note: There is nothing to verify for ReinterpretCastOp.
367
368 // Load additional dialects of which ops may get created.
369 ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
370 cf::ControlFlowDialect>();
371 });
372}
373

source code of mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp