1//===- ComposeSubView.cpp - Combining composed subview ops ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains patterns for combining composed subview ops (i.e. subview
10// of a subview becomes a single subview).
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/MemRef/Transforms/ComposeSubView.h"
15#include "mlir/Dialect/Affine/IR/AffineOps.h"
16#include "mlir/Dialect/MemRef/IR/MemRef.h"
17#include "mlir/IR/BuiltinAttributes.h"
18#include "mlir/IR/OpDefinition.h"
19#include "mlir/IR/PatternMatch.h"
20#include "mlir/Transforms/DialectConversion.h"
21#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22
23using namespace mlir;
24
25namespace {
26
27// Replaces a subview of a subview with a single subview(both static and dynamic
28// offsets are supported).
29struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
30 using OpRewritePattern::OpRewritePattern;
31
32 LogicalResult matchAndRewrite(memref::SubViewOp op,
33 PatternRewriter &rewriter) const override {
34 // 'op' is the 'SubViewOp' we're rewriting. 'sourceOp' is the op that
35 // produces the input of the op we're rewriting (for 'SubViewOp' the input
36 // is called the "source" value). We can only combine them if both 'op' and
37 // 'sourceOp' are 'SubViewOp'.
38 auto sourceOp = op.getSource().getDefiningOp<memref::SubViewOp>();
39 if (!sourceOp)
40 return failure();
41
42 // A 'SubViewOp' can be "rank-reducing" by eliminating dimensions of the
43 // output memref that are statically known to be equal to 1. We do not
44 // allow 'sourceOp' to be a rank-reducing subview because then our two
45 // 'SubViewOp's would have different numbers of offset/size/stride
46 // parameters (just difficult to deal with, not impossible if we end up
47 // needing it).
48 if (sourceOp.getSourceType().getRank() != sourceOp.getType().getRank()) {
49 return failure();
50 }
51
52 // Offsets, sizes and strides OpFoldResult for the combined 'SubViewOp'.
53 SmallVector<OpFoldResult> offsets, sizes, strides,
54 opStrides = op.getMixedStrides(),
55 sourceStrides = sourceOp.getMixedStrides();
56
57 // The output stride in each dimension is equal to the product of the
58 // dimensions corresponding to source and op.
59 int64_t sourceStrideValue;
60 for (auto &&[opStride, sourceStride] :
61 llvm::zip(opStrides, sourceStrides)) {
62 Attribute opStrideAttr = dyn_cast_if_present<Attribute>(opStride);
63 Attribute sourceStrideAttr = dyn_cast_if_present<Attribute>(sourceStride);
64 if (!opStrideAttr || !sourceStrideAttr)
65 return failure();
66 sourceStrideValue = cast<IntegerAttr>(sourceStrideAttr).getInt();
67 strides.push_back(rewriter.getI64IntegerAttr(
68 cast<IntegerAttr>(opStrideAttr).getInt() * sourceStrideValue));
69 }
70
71 // The rules for calculating the new offsets and sizes are:
72 // * Multiple subview offsets for a given dimension compose additively.
73 // ("Offset by m and Stride by k" followed by "Offset by n" == "Offset by
74 // m + n * k")
75 // * Multiple sizes for a given dimension compose by taking the size of the
76 // final subview and ignoring the rest. ("Take m values" followed by "Take
77 // n values" == "Take n values") This size must also be the smallest one
78 // by definition (a subview needs to be the same size as or smaller than
79 // its source along each dimension; presumably subviews that are larger
80 // than their sources are disallowed by validation).
81 for (auto &&[opOffset, sourceOffset, sourceStride, opSize] :
82 llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
83 sourceOp.getMixedStrides(), op.getMixedSizes())) {
84 // We only support static sizes.
85 if (opSize.is<Value>()) {
86 return failure();
87 }
88 sizes.push_back(opSize);
89 Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset),
90 sourceOffsetAttr =
91 llvm::dyn_cast_if_present<Attribute>(sourceOffset),
92 sourceStrideAttr =
93 llvm::dyn_cast_if_present<Attribute>(sourceStride);
94 if (opOffsetAttr && sourceOffsetAttr) {
95
96 // If both offsets are static we can simply calculate the combined
97 // offset statically.
98 offsets.push_back(rewriter.getI64IntegerAttr(
99 cast<IntegerAttr>(opOffsetAttr).getInt() *
100 cast<IntegerAttr>(sourceStrideAttr).getInt() +
101 cast<IntegerAttr>(sourceOffsetAttr).getInt()));
102 } else {
103 AffineExpr expr;
104 SmallVector<Value> affineApplyOperands;
105
106 // Make 'expr' add 'sourceOffset'.
107 if (auto attr = llvm::dyn_cast_if_present<Attribute>(sourceOffset)) {
108 expr =
109 rewriter.getAffineConstantExpr(cast<IntegerAttr>(attr).getInt());
110 } else {
111 expr = rewriter.getAffineSymbolExpr(affineApplyOperands.size());
112 affineApplyOperands.push_back(sourceOffset.get<Value>());
113 }
114
115 // Multiply 'opOffset' by 'sourceStride' and make the 'expr' add the
116 // result.
117 if (auto attr = llvm::dyn_cast_if_present<Attribute>(opOffset)) {
118 expr = expr + cast<IntegerAttr>(attr).getInt() *
119 cast<IntegerAttr>(sourceStrideAttr).getInt();
120 } else {
121 expr =
122 expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size()) *
123 cast<IntegerAttr>(sourceStrideAttr).getInt();
124 affineApplyOperands.push_back(opOffset.get<Value>());
125 }
126
127 AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr);
128 Value result = rewriter.create<affine::AffineApplyOp>(
129 op.getLoc(), map, affineApplyOperands);
130 offsets.push_back(result);
131 }
132 }
133
134 // This replaces 'op' but leaves 'sourceOp' alone; if it no longer has any
135 // uses it can be removed by a (separate) dead code elimination pass.
136 rewriter.replaceOpWithNewOp<memref::SubViewOp>(
137 op, op.getType(), sourceOp.getSource(), offsets, sizes, strides);
138 return success();
139 }
140};
141
142} // namespace
143
144void mlir::memref::populateComposeSubViewPatterns(RewritePatternSet &patterns,
145 MLIRContext *context) {
146 patterns.add<ComposeSubViewOpPattern>(arg&: context);
147}
148

source code of mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp