//===- ComposeSubView.cpp - Combining composed subview ops ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains patterns for combining composed subview ops (i.e. subview
// of a subview becomes a single subview).
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/Transforms/ComposeSubView.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
23
24namespace {
25
26// Replaces a subview of a subview with a single subview(both static and dynamic
27// offsets are supported).
28struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
29 using OpRewritePattern::OpRewritePattern;
30
31 LogicalResult matchAndRewrite(memref::SubViewOp op,
32 PatternRewriter &rewriter) const override {
33 // 'op' is the 'SubViewOp' we're rewriting. 'sourceOp' is the op that
34 // produces the input of the op we're rewriting (for 'SubViewOp' the input
35 // is called the "source" value). We can only combine them if both 'op' and
36 // 'sourceOp' are 'SubViewOp'.
37 auto sourceOp = op.getSource().getDefiningOp<memref::SubViewOp>();
38 if (!sourceOp)
39 return failure();
40
41 // A 'SubViewOp' can be "rank-reducing" by eliminating dimensions of the
42 // output memref that are statically known to be equal to 1. We do not
43 // allow 'sourceOp' to be a rank-reducing subview because then our two
44 // 'SubViewOp's would have different numbers of offset/size/stride
45 // parameters (just difficult to deal with, not impossible if we end up
46 // needing it).
47 if (sourceOp.getSourceType().getRank() != sourceOp.getType().getRank()) {
48 return failure();
49 }
50
51 // Offsets, sizes and strides OpFoldResult for the combined 'SubViewOp'.
52 SmallVector<OpFoldResult> offsets, sizes, strides,
53 opStrides = op.getMixedStrides(),
54 sourceStrides = sourceOp.getMixedStrides();
55
56 // The output stride in each dimension is equal to the product of the
57 // dimensions corresponding to source and op.
58 int64_t sourceStrideValue;
59 for (auto &&[opStride, sourceStride] :
60 llvm::zip(t&: opStrides, u&: sourceStrides)) {
61 Attribute opStrideAttr = dyn_cast_if_present<Attribute>(Val&: opStride);
62 Attribute sourceStrideAttr = dyn_cast_if_present<Attribute>(Val&: sourceStride);
63 if (!opStrideAttr || !sourceStrideAttr)
64 return failure();
65 sourceStrideValue = cast<IntegerAttr>(Val&: sourceStrideAttr).getInt();
66 strides.push_back(Elt: rewriter.getI64IntegerAttr(
67 value: cast<IntegerAttr>(Val&: opStrideAttr).getInt() * sourceStrideValue));
68 }
69
70 // The rules for calculating the new offsets and sizes are:
71 // * Multiple subview offsets for a given dimension compose additively.
72 // ("Offset by m and Stride by k" followed by "Offset by n" == "Offset by
73 // m + n * k")
74 // * Multiple sizes for a given dimension compose by taking the size of the
75 // final subview and ignoring the rest. ("Take m values" followed by "Take
76 // n values" == "Take n values") This size must also be the smallest one
77 // by definition (a subview needs to be the same size as or smaller than
78 // its source along each dimension; presumably subviews that are larger
79 // than their sources are disallowed by validation).
80 for (auto &&[opOffset, sourceOffset, sourceStride, opSize] :
81 llvm::zip(t: op.getMixedOffsets(), u: sourceOp.getMixedOffsets(),
82 args: sourceOp.getMixedStrides(), args: op.getMixedSizes())) {
83 // We only support static sizes.
84 if (isa<Value>(Val: opSize)) {
85 return failure();
86 }
87 sizes.push_back(Elt: opSize);
88 Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(Val&: opOffset),
89 sourceOffsetAttr =
90 llvm::dyn_cast_if_present<Attribute>(Val&: sourceOffset),
91 sourceStrideAttr =
92 llvm::dyn_cast_if_present<Attribute>(Val&: sourceStride);
93 if (opOffsetAttr && sourceOffsetAttr) {
94
95 // If both offsets are static we can simply calculate the combined
96 // offset statically.
97 offsets.push_back(Elt: rewriter.getI64IntegerAttr(
98 value: cast<IntegerAttr>(Val&: opOffsetAttr).getInt() *
99 cast<IntegerAttr>(Val&: sourceStrideAttr).getInt() +
100 cast<IntegerAttr>(Val&: sourceOffsetAttr).getInt()));
101 } else {
102 AffineExpr expr;
103 SmallVector<Value> affineApplyOperands;
104
105 // Make 'expr' add 'sourceOffset'.
106 if (auto attr = llvm::dyn_cast_if_present<Attribute>(Val&: sourceOffset)) {
107 expr =
108 rewriter.getAffineConstantExpr(constant: cast<IntegerAttr>(Val&: attr).getInt());
109 } else {
110 expr = rewriter.getAffineSymbolExpr(position: affineApplyOperands.size());
111 affineApplyOperands.push_back(Elt: cast<Value>(Val&: sourceOffset));
112 }
113
114 // Multiply 'opOffset' by 'sourceStride' and make the 'expr' add the
115 // result.
116 if (auto attr = llvm::dyn_cast_if_present<Attribute>(Val&: opOffset)) {
117 expr = expr + cast<IntegerAttr>(Val&: attr).getInt() *
118 cast<IntegerAttr>(Val&: sourceStrideAttr).getInt();
119 } else {
120 expr =
121 expr + rewriter.getAffineSymbolExpr(position: affineApplyOperands.size()) *
122 cast<IntegerAttr>(Val&: sourceStrideAttr).getInt();
123 affineApplyOperands.push_back(Elt: cast<Value>(Val&: opOffset));
124 }
125
126 AffineMap map = AffineMap::get(dimCount: 0, symbolCount: affineApplyOperands.size(), result: expr);
127 Value result = rewriter.create<affine::AffineApplyOp>(
128 location: op.getLoc(), args&: map, args&: affineApplyOperands);
129 offsets.push_back(Elt: result);
130 }
131 }
132
133 // This replaces 'op' but leaves 'sourceOp' alone; if it no longer has any
134 // uses it can be removed by a (separate) dead code elimination pass.
135 rewriter.replaceOpWithNewOp<memref::SubViewOp>(
136 op, args: op.getType(), args: sourceOp.getSource(), args&: offsets, args&: sizes, args&: strides);
137 return success();
138 }
139};
140
141} // namespace
142
143void mlir::memref::populateComposeSubViewPatterns(RewritePatternSet &patterns,
144 MLIRContext *context) {
145 patterns.add<ComposeSubViewOpPattern>(arg&: context);
146}
// Source: mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp