//===- LowerVectorGather.cpp - Lower 'vector.gather' operation -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.gather' operation.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Support/LogicalResult.h"

#define DEBUG_TYPE "vector-gather-lowering"

using namespace mlir;
using namespace mlir::vector;

namespace {
/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
/// outermost dimension. For example:
/// ```
/// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
///        ... into vector<2x3xf32>
///
/// ==>
///
/// %0 = arith.constant dense<0.0> : vector<2x3xf32>
/// %g0 = vector.gather %base[%c0][%v0], %mask0, %pass_thru0 : ...
/// %1 = vector.insert %g0, %0 [0] : vector<3xf32> into vector<2x3xf32>
/// %g1 = vector.gather %base[%c0][%v1], %mask1, %pass_thru1 : ...
/// %g = vector.insert %g1, %1 [1] : vector<3xf32> into vector<2x3xf32>
/// ```
/// where %v0/%v1, %mask0/%mask1 and %pass_thru0/%pass_thru1 are the 1-D
/// slices extracted from %v, %mask and %pass_thru, respectively.
///
/// When applied exhaustively, this will produce a sequence of 1-D gather ops.
struct FlattenGather : OpRewritePattern<vector::GatherOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::GatherOp op,
                                PatternRewriter &rewriter) const override {
    VectorType resultTy = op.getType();
    if (resultTy.getRank() < 2)
      return rewriter.notifyMatchFailure(op, "already flat");

    Location loc = op.getLoc();
    Value indexVec = op.getIndexVec();
    Value maskVec = op.getMask();
    Value passThruVec = op.getPassThru();

    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultTy, rewriter.getZeroAttr(resultTy));

    Type subTy = VectorType::get(resultTy.getShape().drop_front(),
                                 resultTy.getElementType());

    for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
      int64_t thisIdx[1] = {i};

      Value indexSubVec =
          rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
      Value maskSubVec =
          rewriter.create<vector::ExtractOp>(loc, maskVec, thisIdx);
      Value passThruSubVec =
          rewriter.create<vector::ExtractOp>(loc, passThruVec, thisIdx);
      Value subGather = rewriter.create<vector::GatherOp>(
          loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec,
          passThruSubVec);
      result =
          rewriter.create<vector::InsertOp>(loc, subGather, result, thisIdx);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
/// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
/// MemRef with updated indices that model the strided access.
///
/// ```mlir
/// %subview = memref.subview %M (...)
///   : memref<100x3xf32> to memref<100xf32, strided<[3]>>
/// %gather = vector.gather %subview[%idxs] (...)
///   : memref<100xf32, strided<[3]>>
/// ```
/// ==>
/// ```mlir
/// %collapse_shape = memref.collapse_shape %M (...)
///   : memref<100x3xf32> into memref<300xf32>
/// %new_idxs = arith.muli %idxs, %c3 : vector<4xindex>
/// %gather = vector.gather %collapse_shape[%new_idxs] (...)
///   : memref<300xf32> (...)
/// ```
///
/// Currently this is effectively limited to reading a 1-D vector from a 2-D
/// MemRef, but it should be fairly straightforward to extend beyond that.
struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::GatherOp op,
                                PatternRewriter &rewriter) const override {
    Value base = op.getBase();

    // TODO: Strided accesses might be coming from other ops as well.
    auto subview = base.getDefiningOp<memref::SubViewOp>();
    if (!subview)
      return failure();

    auto sourceType = subview.getSource().getType();

    // TODO: Allow ranks > 2.
    if (sourceType.getRank() != 2)
      return failure();

    // Get strides.
    auto layout = subview.getResult().getType().getLayout();
    auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
    if (!stridedLayoutAttr)
      return failure();

    // TODO: Allow the access to be strided in multiple dimensions.
    if (stridedLayoutAttr.getStrides().size() != 1)
      return failure();

    int64_t srcTrailingDim = sourceType.getShape().back();

    // Assume that the stride matches the trailing dimension of the source
    // memref.
    // TODO: Relax this assumption.
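    // E.g., for the memref<100x3xf32> source from the example above, only a
    // subview stride of 3 (the size of the trailing dimension) is accepted.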
    if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim)
      return failure();

    // 1. Collapse the input memref so that it's "flat".
    SmallVector<ReassociationIndices> reassoc = {{0, 1}};
    Value collapsed = rewriter.create<memref::CollapseShapeOp>(
        op.getLoc(), subview.getSource(), reassoc);

    // 2. Generate new gather indices that model the strided access.
    IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim);
    VectorType vType = op.getIndexVec().getType();
    Value mulCst = rewriter.create<arith::ConstantOp>(
        op.getLoc(), vType, DenseElementsAttr::get(vType, stride));

    Value newIdxs =
        rewriter.create<arith::MulIOp>(op.getLoc(), op.getIndexVec(), mulCst);

    // 3. Create an updated gather op with the collapsed input memref and the
    //    updated indices.
    Value newGather = rewriter.create<vector::GatherOp>(
        op.getLoc(), op.getResult().getType(), collapsed, op.getIndices(),
        newIdxs, op.getMask(), op.getPassThru());
    rewriter.replaceOp(op, newGather);

    return success();
  }
};

/// Turns 1-D `vector.gather` into a scalarized sequence of `vector.load`s or
/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
/// loads/extracts are made conditional using `scf.if` ops.
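///
/// For example, a gather producing vector<2xf32> from a memref is (roughly)
/// rewritten as one `scf.if` per element; for element 0 (SSA names below are
/// illustrative):
/// ```
/// %c0 = vector.extract %mask[0] (...)
/// %i0 = vector.extract %idxs[0] (...)
/// %off = arith.addi %base_off, %i0 : index
/// %r0 = scf.if %c0 -> (vector<2xf32>) {
///   %ld = vector.load %base[%off] : memref<?xf32>, vector<1xf32>
///   %e = vector.extract %ld[0] (...)
///   %v = vector.insert %e, %pass_thru [0] (...)
///   scf.yield %v : vector<2xf32>
/// } else {
///   scf.yield %pass_thru : vector<2xf32>
/// }
/// ```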
struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::GatherOp op,
                                PatternRewriter &rewriter) const override {
    VectorType resultTy = op.getType();
    if (resultTy.getRank() != 1)
      return rewriter.notifyMatchFailure(op, "unsupported rank");

    Location loc = op.getLoc();
    Type elemTy = resultTy.getElementType();
    // Vector type with a single element. Used to generate `vector.load`s.
    VectorType elemVecTy = VectorType::get({1}, elemTy);

    Value condMask = op.getMask();
    Value base = op.getBase();

    // vector.load requires the most minor memref dim to have unit stride.
    if (auto memType = dyn_cast<MemRefType>(base.getType())) {
      if (auto stridesAttr =
              dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
        if (stridesAttr.getStrides().back() != 1)
          return failure();
      }
    }

    Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
        loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
        op.getIndexVec());
    auto baseOffsets = llvm::to_vector(op.getIndices());
    Value lastBaseOffset = baseOffsets.back();

    Value result = op.getPassThru();

    // Emit a conditional access for each vector element.
    for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) {
      int64_t thisIdx[1] = {i};
      Value condition =
          rewriter.create<vector::ExtractOp>(loc, condMask, thisIdx);
      Value index = rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
      baseOffsets.back() =
          rewriter.createOrFold<arith::AddIOp>(loc, lastBaseOffset, index);

      auto loadBuilder = [&](OpBuilder &b, Location loc) {
        Value extracted;
        if (isa<MemRefType>(base.getType())) {
          // `vector.load` does not support scalar results; emit a one-element
          // vector load and extract the single result instead.
          Value load =
              b.create<vector::LoadOp>(loc, elemVecTy, base, baseOffsets);
          int64_t zeroIdx[1] = {0};
          extracted = b.create<vector::ExtractOp>(loc, load, zeroIdx);
        } else {
          extracted = b.create<tensor::ExtractOp>(loc, base, baseOffsets);
        }

        Value newResult =
            b.create<vector::InsertOp>(loc, extracted, result, thisIdx);
        b.create<scf::YieldOp>(loc, newResult);
      };
      auto passThruBuilder = [result](OpBuilder &b, Location loc) {
        b.create<scf::YieldOp>(loc, result);
      };

      result =
          rewriter
              .create<scf::IfOp>(loc, condition, /*thenBuilder=*/loadBuilder,
                                 /*elseBuilder=*/passThruBuilder)
              .getResult(0);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace

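/// Populates `patterns` with the gather-lowering patterns above. A minimal
/// usage sketch (assuming a pass using the standard greedy rewrite driver;
/// `ctx` and `getOperation()` come from that pass and are not part of this
/// file):
///
///   RewritePatternSet patterns(ctx);
///   vector::populateVectorGatherLoweringPatterns(patterns);
///   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));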
void mlir::vector::populateVectorGatherLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<FlattenGather, RemoveStrideFromGatherSource,
               Gather1DToConditionalLoads>(patterns.getContext(), benefit);
}
