1//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
10
11#include "mlir/Dialect/Affine/IR/AffineOps.h"
12#include "mlir/Dialect/Arith/IR/Arith.h"
13#include "mlir/Dialect/Arith/Utils/Utils.h"
14#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
15#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
16#include "mlir/Dialect/MemRef/IR/MemRef.h"
17#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
18#include "mlir/Dialect/Utils/IndexingUtils.h"
19#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
20
21using namespace mlir;
22
23namespace mlir {
24namespace memref {
25namespace {
26/// Generate a runtime check for lb <= value < ub.
27Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
28 Value lb, Value ub) {
29 Value inBounds1 = builder.createOrFold<arith::CmpIOp>(
30 loc, arith::CmpIPredicate::sge, value, lb);
31 Value inBounds2 = builder.createOrFold<arith::CmpIOp>(
32 loc, arith::CmpIPredicate::slt, value, ub);
33 Value inBounds =
34 builder.createOrFold<arith::AndIOp>(loc, inBounds1, inBounds2);
35 return inBounds;
36}
37
38struct AssumeAlignmentOpInterface
39 : public RuntimeVerifiableOpInterface::ExternalModel<
40 AssumeAlignmentOpInterface, AssumeAlignmentOp> {
41 void generateRuntimeVerification(Operation *op, OpBuilder &builder,
42 Location loc) const {
43 auto assumeOp = cast<AssumeAlignmentOp>(op);
44 Value ptr = builder.create<ExtractAlignedPointerAsIndexOp>(
45 loc, assumeOp.getMemref());
46 Value rest = builder.create<arith::RemUIOp>(
47 loc, ptr,
48 builder.create<arith::ConstantIndexOp>(loc, assumeOp.getAlignment()));
49 Value isAligned = builder.create<arith::CmpIOp>(
50 loc, arith::CmpIPredicate::eq, rest,
51 builder.create<arith::ConstantIndexOp>(loc, 0));
52 builder.create<cf::AssertOp>(
53 loc, isAligned,
54 RuntimeVerifiableOpInterface::generateErrorMessage(
55 op, "memref is not aligned to " +
56 std::to_string(assumeOp.getAlignment())));
57 }
58};
59
60struct CastOpInterface
61 : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
62 CastOp> {
63 void generateRuntimeVerification(Operation *op, OpBuilder &builder,
64 Location loc) const {
65 auto castOp = cast<CastOp>(op);
66 auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
67
68 // Nothing to check if the result is an unranked memref.
69 auto resultType = dyn_cast<MemRefType>(castOp.getType());
70 if (!resultType)
71 return;
72
73 if (isa<UnrankedMemRefType>(srcType)) {
74 // Check rank.
75 Value srcRank = builder.create<RankOp>(loc, castOp.getSource());
76 Value resultRank =
77 builder.create<arith::ConstantIndexOp>(loc, resultType.getRank());
78 Value isSameRank = builder.create<arith::CmpIOp>(
79 loc, arith::CmpIPredicate::eq, srcRank, resultRank);
80 builder.create<cf::AssertOp>(
81 loc, isSameRank,
82 RuntimeVerifiableOpInterface::generateErrorMessage(op,
83 "rank mismatch"));
84 }
85
86 // Get source offset and strides. We do not have an op to get offsets and
87 // strides from unranked memrefs, so cast the source to a type with fully
88 // dynamic layout, from which we can then extract the offset and strides.
89 // (Rank was already verified.)
90 int64_t dynamicOffset = ShapedType::kDynamic;
91 SmallVector<int64_t> dynamicShape(resultType.getRank(),
92 ShapedType::kDynamic);
93 auto stridedLayout = StridedLayoutAttr::get(builder.getContext(),
94 dynamicOffset, dynamicShape);
95 auto dynStridesType =
96 MemRefType::get(dynamicShape, resultType.getElementType(),
97 stridedLayout, resultType.getMemorySpace());
98 Value helperCast =
99 builder.create<CastOp>(loc, dynStridesType, castOp.getSource());
100 auto metadataOp = builder.create<ExtractStridedMetadataOp>(loc, helperCast);
101
102 // Check dimension sizes.
103 for (const auto &it : llvm::enumerate(resultType.getShape())) {
104 // Static dim size -> static/dynamic dim size does not need verification.
105 if (auto rankedSrcType = dyn_cast<MemRefType>(srcType))
106 if (!rankedSrcType.isDynamicDim(it.index()))
107 continue;
108
109 // Static/dynamic dim size -> dynamic dim size does not need verification.
110 if (resultType.isDynamicDim(it.index()))
111 continue;
112
113 Value srcDimSz =
114 builder.create<DimOp>(loc, castOp.getSource(), it.index());
115 Value resultDimSz =
116 builder.create<arith::ConstantIndexOp>(loc, it.value());
117 Value isSameSz = builder.create<arith::CmpIOp>(
118 loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
119 builder.create<cf::AssertOp>(
120 loc, isSameSz,
121 RuntimeVerifiableOpInterface::generateErrorMessage(
122 op, "size mismatch of dim " + std::to_string(it.index())));
123 }
124
125 // Get result offset and strides.
126 int64_t resultOffset;
127 SmallVector<int64_t> resultStrides;
128 if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
129 return;
130
131 // Check offset.
132 if (resultOffset != ShapedType::kDynamic) {
133 // Static/dynamic offset -> dynamic offset does not need verification.
134 Value srcOffset = metadataOp.getResult(1);
135 Value resultOffsetVal =
136 builder.create<arith::ConstantIndexOp>(loc, resultOffset);
137 Value isSameOffset = builder.create<arith::CmpIOp>(
138 loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
139 builder.create<cf::AssertOp>(
140 loc, isSameOffset,
141 RuntimeVerifiableOpInterface::generateErrorMessage(
142 op, "offset mismatch"));
143 }
144
145 // Check strides.
146 for (const auto &it : llvm::enumerate(resultStrides)) {
147 // Static/dynamic stride -> dynamic stride does not need verification.
148 if (it.value() == ShapedType::kDynamic)
149 continue;
150
151 Value srcStride =
152 metadataOp.getResult(2 + resultType.getRank() + it.index());
153 Value resultStrideVal =
154 builder.create<arith::ConstantIndexOp>(loc, it.value());
155 Value isSameStride = builder.create<arith::CmpIOp>(
156 loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
157 builder.create<cf::AssertOp>(
158 loc, isSameStride,
159 RuntimeVerifiableOpInterface::generateErrorMessage(
160 op, "stride mismatch of dim " + std::to_string(it.index())));
161 }
162 }
163};
164
165struct CopyOpInterface
166 : public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
167 CopyOp> {
168 void generateRuntimeVerification(Operation *op, OpBuilder &builder,
169 Location loc) const {
170 auto copyOp = cast<CopyOp>(op);
171 BaseMemRefType sourceType = copyOp.getSource().getType();
172 BaseMemRefType targetType = copyOp.getTarget().getType();
173 auto rankedSourceType = dyn_cast<MemRefType>(sourceType);
174 auto rankedTargetType = dyn_cast<MemRefType>(targetType);
175
176 // TODO: Verification for unranked memrefs is not supported yet.
177 if (!rankedSourceType || !rankedTargetType)
178 return;
179
180 assert(sourceType.getRank() == targetType.getRank() && "rank mismatch");
181 for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
182 // Fully static dimensions in both source and target operand are already
183 // verified by the op verifier.
184 if (!rankedSourceType.isDynamicDim(i) &&
185 !rankedTargetType.isDynamicDim(i))
186 continue;
187 auto getDimSize = [&](Value memRef, MemRefType type,
188 int64_t dim) -> Value {
189 return type.isDynamicDim(dim)
190 ? builder.create<DimOp>(loc, memRef, dim).getResult()
191 : builder
192 .create<arith::ConstantIndexOp>(loc,
193 type.getDimSize(dim))
194 .getResult();
195 };
196 Value sourceDim = getDimSize(copyOp.getSource(), rankedSourceType, i);
197 Value targetDim = getDimSize(copyOp.getTarget(), rankedTargetType, i);
198 Value sameDimSize = builder.create<arith::CmpIOp>(
199 loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
200 builder.create<cf::AssertOp>(
201 loc, sameDimSize,
202 RuntimeVerifiableOpInterface::generateErrorMessage(
203 op, "size of " + std::to_string(i) +
204 "-th source/target dim does not match"));
205 }
206 }
207};
208
209struct DimOpInterface
210 : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
211 DimOp> {
212 void generateRuntimeVerification(Operation *op, OpBuilder &builder,
213 Location loc) const {
214 auto dimOp = cast<DimOp>(op);
215 Value rank = builder.create<RankOp>(loc, dimOp.getSource());
216 Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
217 builder.create<cf::AssertOp>(
218 loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
219 RuntimeVerifiableOpInterface::generateErrorMessage(
220 op, "index is out of bounds"));
221 }
222};
223
224/// Verifies that the indices on load/store ops are in-bounds of the memref's
225/// index space: 0 <= index#i < dim#i
226template <typename LoadStoreOp>
227struct LoadStoreOpInterface
228 : public RuntimeVerifiableOpInterface::ExternalModel<
229 LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
230 void generateRuntimeVerification(Operation *op, OpBuilder &builder,
231 Location loc) const {
232 auto loadStoreOp = cast<LoadStoreOp>(op);
233
234 auto memref = loadStoreOp.getMemref();
235 auto rank = memref.getType().getRank();
236 if (rank == 0) {
237 return;
238 }
239 auto indices = loadStoreOp.getIndices();
240
241 auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
242 Value assertCond;
243 for (auto i : llvm::seq<int64_t>(0, rank)) {
244 Value dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
245 Value inBounds =
246 generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
247 assertCond =
248 i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
249 : inBounds;
250 }
251 builder.create<cf::AssertOp>(
252 loc, assertCond,
253 RuntimeVerifiableOpInterface::generateErrorMessage(
254 op, "out-of-bounds access"));
255 }
256};
257
258struct SubViewOpInterface
259 : public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
260 SubViewOp> {
261 void generateRuntimeVerification(Operation *op, OpBuilder &builder,
262 Location loc) const {
263 auto subView = cast<SubViewOp>(op);
264 MemRefType sourceType = subView.getSource().getType();
265
266 // For each dimension, assert that:
267 // 0 <= offset < dim_size
268 // 0 <= offset + (size - 1) * stride < dim_size
269 Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
270 Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
271 auto metadataOp =
272 builder.create<ExtractStridedMetadataOp>(loc, subView.getSource());
273 for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
274 Value offset = getValueOrCreateConstantIndexOp(
275 builder, loc, subView.getMixedOffsets()[i]);
276 Value size = getValueOrCreateConstantIndexOp(builder, loc,
277 subView.getMixedSizes()[i]);
278 Value stride = getValueOrCreateConstantIndexOp(
279 builder, loc, subView.getMixedStrides()[i]);
280
281 // Verify that offset is in-bounds.
282 Value dimSize = metadataOp.getSizes()[i];
283 Value offsetInBounds =
284 generateInBoundsCheck(builder, loc, value: offset, lb: zero, ub: dimSize);
285 builder.create<cf::AssertOp>(
286 loc, offsetInBounds,
287 RuntimeVerifiableOpInterface::generateErrorMessage(
288 op, "offset " + std::to_string(i) + " is out-of-bounds"));
289
290 // Verify that slice does not run out-of-bounds.
291 Value sizeMinusOne = builder.create<arith::SubIOp>(loc, size, one);
292 Value sizeMinusOneTimesStride =
293 builder.create<arith::MulIOp>(loc, sizeMinusOne, stride);
294 Value lastPos =
295 builder.create<arith::AddIOp>(loc, offset, sizeMinusOneTimesStride);
296 Value lastPosInBounds =
297 generateInBoundsCheck(builder, loc, value: lastPos, lb: zero, ub: dimSize);
298 builder.create<cf::AssertOp>(
299 loc, lastPosInBounds,
300 RuntimeVerifiableOpInterface::generateErrorMessage(
301 op, "subview runs out-of-bounds along dimension " +
302 std::to_string(i)));
303 }
304 }
305};
306
307struct ExpandShapeOpInterface
308 : public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
309 ExpandShapeOp> {
310 void generateRuntimeVerification(Operation *op, OpBuilder &builder,
311 Location loc) const {
312 auto expandShapeOp = cast<ExpandShapeOp>(op);
313
314 // Verify that the expanded dim sizes are a product of the collapsed dim
315 // size.
316 for (const auto &it :
317 llvm::enumerate(expandShapeOp.getReassociationIndices())) {
318 Value srcDimSz =
319 builder.create<DimOp>(loc, expandShapeOp.getSrc(), it.index());
320 int64_t groupSz = 1;
321 bool foundDynamicDim = false;
322 for (int64_t resultDim : it.value()) {
323 if (expandShapeOp.getResultType().isDynamicDim(resultDim)) {
324 // Keep this assert here in case the op is extended in the future.
325 assert(!foundDynamicDim &&
326 "more than one dynamic dim found in reassoc group");
327 (void)foundDynamicDim;
328 foundDynamicDim = true;
329 continue;
330 }
331 groupSz *= expandShapeOp.getResultType().getDimSize(resultDim);
332 }
333 Value staticResultDimSz =
334 builder.create<arith::ConstantIndexOp>(loc, groupSz);
335 // staticResultDimSz must divide srcDimSz evenly.
336 Value mod =
337 builder.create<arith::RemSIOp>(loc, srcDimSz, staticResultDimSz);
338 Value isModZero = builder.create<arith::CmpIOp>(
339 loc, arith::CmpIPredicate::eq, mod,
340 builder.create<arith::ConstantIndexOp>(loc, 0));
341 builder.create<cf::AssertOp>(
342 loc, isModZero,
343 RuntimeVerifiableOpInterface::generateErrorMessage(
344 op, "static result dims in reassoc group do not "
345 "divide src dim evenly"));
346 }
347 }
348};
349} // namespace
350} // namespace memref
351} // namespace mlir
352
353void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
354 DialectRegistry &registry) {
355 registry.addExtension(extensionFn: +[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
356 AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
357 AtomicRMWOp::attachInterface<LoadStoreOpInterface<AtomicRMWOp>>(*ctx);
358 CastOp::attachInterface<CastOpInterface>(*ctx);
359 CopyOp::attachInterface<CopyOpInterface>(*ctx);
360 DimOp::attachInterface<DimOpInterface>(*ctx);
361 ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
362 GenericAtomicRMWOp::attachInterface<
363 LoadStoreOpInterface<GenericAtomicRMWOp>>(*ctx);
364 LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
365 StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
366 SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
367 // Note: There is nothing to verify for ReinterpretCastOp.
368
369 // Load additional dialects of which ops may get created.
370 ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
371 cf::ControlFlowDialect>();
372 });
373}
374

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp