1//===- InferShapeTest.cpp - unit tests for shape inference ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/MemRef/IR/MemRef.h"
10#include "mlir/IR/AffineMap.h"
11#include "mlir/IR/Builders.h"
12#include "mlir/IR/BuiltinTypes.h"
13#include "gtest/gtest.h"
14
15using namespace mlir;
16using namespace mlir::memref;
17
18// Source memref has identity layout.
19TEST(InferShapeTest, inferRankReducedShapeIdentity) {
20 MLIRContext ctx;
21 OpBuilder b(&ctx);
22 auto sourceMemref = MemRefType::get(shape: {10, 5}, elementType: b.getIndexType());
23 auto reducedType = SubViewOp::inferRankReducedResultType(
24 /*resultShape=*/{2}, sourceMemRefType: sourceMemref, staticOffsets: {2, 3}, staticSizes: {1, 2}, staticStrides: {1, 1});
25 auto expectedType = MemRefType::get(
26 shape: {2}, elementType: b.getIndexType(),
27 layout: StridedLayoutAttr::get(context: &ctx, /*offset=*/13, /*strides=*/{1}));
28 EXPECT_EQ(reducedType, expectedType);
29}
30
31// Source memref has non-identity layout.
32TEST(InferShapeTest, inferRankReducedShapeNonIdentity) {
33 MLIRContext ctx;
34 OpBuilder b(&ctx);
35 AffineExpr dim0, dim1;
36 bindDims(ctx: &ctx, exprs&: dim0, exprs&: dim1);
37 auto sourceMemref = MemRefType::get(shape: {10, 5}, elementType: b.getIndexType(),
38 map: AffineMap::get(dimCount: 2, symbolCount: 0, result: 1000 * dim0 + dim1));
39 auto reducedType = SubViewOp::inferRankReducedResultType(
40 /*resultShape=*/{2}, sourceMemRefType: sourceMemref, staticOffsets: {2, 3}, staticSizes: {1, 2}, staticStrides: {1, 1});
41 auto expectedType = MemRefType::get(
42 shape: {2}, elementType: b.getIndexType(),
43 layout: StridedLayoutAttr::get(context: &ctx, /*offset=*/2003, /*strides=*/{1}));
44 EXPECT_EQ(reducedType, expectedType);
45}
46
47TEST(InferShapeTest, inferRankReducedShapeToScalar) {
48 MLIRContext ctx;
49 OpBuilder b(&ctx);
50 AffineExpr dim0, dim1;
51 bindDims(ctx: &ctx, exprs&: dim0, exprs&: dim1);
52 auto sourceMemref = MemRefType::get(shape: {10, 5}, elementType: b.getIndexType(),
53 map: AffineMap::get(dimCount: 2, symbolCount: 0, result: 1000 * dim0 + dim1));
54 auto reducedType = SubViewOp::inferRankReducedResultType(
55 /*resultShape=*/{}, sourceMemRefType: sourceMemref, staticOffsets: {2, 3}, staticSizes: {1, 1}, staticStrides: {1, 1});
56 auto expectedType = MemRefType::get(
57 shape: {}, elementType: b.getIndexType(),
58 layout: StridedLayoutAttr::get(context: &ctx, /*offset=*/2003, /*strides=*/{}));
59 EXPECT_EQ(reducedType, expectedType);
60}
61

source code of mlir/unittests/Dialect/MemRef/InferShapeTest.cpp