//===- LowerVectorGather.cpp - Lower 'vector.gather' operation ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.gather' operation.
//
//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/Arith/Utils/Utils.h"
16#include "mlir/Dialect/MemRef/IR/MemRef.h"
17#include "mlir/Dialect/SCF/IR/SCF.h"
18#include "mlir/Dialect/Tensor/IR/Tensor.h"
19#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
20#include "mlir/Dialect/Vector/IR/VectorOps.h"
21#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
22#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
23#include "mlir/IR/BuiltinTypes.h"
24#include "mlir/IR/Location.h"
25#include "mlir/IR/PatternMatch.h"
26#include "mlir/IR/TypeUtilities.h"
27
28#define DEBUG_TYPE "vector-broadcast-lowering"
29
30using namespace mlir;
31using namespace mlir::vector;
32
33namespace {
/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
/// outermost dimension. For example:
/// ```
/// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
///        ... into vector<2x3xf32>
///
/// ==>
///
/// %0   = arith.constant dense<0.0> : vector<2x3xf32>
/// %g0  = vector.gather %base[%c0][%v0], %mask0, %pass_thru0 : ...
/// %1   = vector.insert %g0, %0 [0] : vector<3xf32> into vector<2x3xf32>
/// %g1  = vector.gather %base[%c0][%v1], %mask1, %pass_thru1 : ...
/// %g   = vector.insert %g1, %1 [1] : vector<3xf32> into vector<2x3xf32>
/// ```
///
/// When applied exhaustively, this will produce a sequence of 1-d gather ops.
///
/// Supports vector types with a fixed leading dimension.
52struct UnrollGather : OpRewritePattern<vector::GatherOp> {
53 using OpRewritePattern::OpRewritePattern;
54
55 LogicalResult matchAndRewrite(vector::GatherOp op,
56 PatternRewriter &rewriter) const override {
57 VectorType resultTy = op.getType();
58 if (resultTy.getRank() < 2)
59 return rewriter.notifyMatchFailure(arg&: op, msg: "already 1-D");
60
61 // Unrolling doesn't take vscale into account. Pattern is disabled for
62 // vectors with leading scalable dim(s).
63 if (resultTy.getScalableDims().front())
64 return rewriter.notifyMatchFailure(arg&: op, msg: "cannot unroll scalable dim");
65
66 Location loc = op.getLoc();
67 Value indexVec = op.getIndexVec();
68 Value maskVec = op.getMask();
69 Value passThruVec = op.getPassThru();
70
71 Value result = rewriter.create<arith::ConstantOp>(
72 location: loc, args&: resultTy, args: rewriter.getZeroAttr(type: resultTy));
73
74 VectorType subTy = VectorType::Builder(resultTy).dropDim(pos: 0);
75
76 for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
77 int64_t thisIdx[1] = {i};
78
79 Value indexSubVec =
80 rewriter.create<vector::ExtractOp>(location: loc, args&: indexVec, args&: thisIdx);
81 Value maskSubVec =
82 rewriter.create<vector::ExtractOp>(location: loc, args&: maskVec, args&: thisIdx);
83 Value passThruSubVec =
84 rewriter.create<vector::ExtractOp>(location: loc, args&: passThruVec, args&: thisIdx);
85 Value subGather = rewriter.create<vector::GatherOp>(
86 location: loc, args&: subTy, args: op.getBase(), args: op.getIndices(), args&: indexSubVec, args&: maskSubVec,
87 args&: passThruSubVec);
88 result =
89 rewriter.create<vector::InsertOp>(location: loc, args&: subGather, args&: result, args&: thisIdx);
90 }
91
92 rewriter.replaceOp(op, newValues: result);
93 return success();
94 }
95};
96
/// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
/// MemRef with updated indices that model the strided access.
///
/// ```mlir
///   %subview = memref.subview %M (...)
///     : memref<100x3xf32> to memref<100xf32, strided<[3]>>
///   %gather = vector.gather %subview[%idxs] (...)
///     : memref<100xf32, strided<[3]>>
/// ```
/// ==>
/// ```mlir
///   %collapse_shape = memref.collapse_shape %M (...)
///     : memref<100x3xf32> into memref<300xf32>
///   %new_idxs = arith.muli %idxs, %c3 : vector<4xindex>
///   %gather = vector.gather %collapse_shape[%new_idxs] (...)
///     : memref<300xf32> (...)
/// ```
///
/// ATM this is effectively limited to reading a 1D Vector from a 2D MemRef,
/// but should be fairly straightforward to extend beyond that.
117struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
118 using OpRewritePattern::OpRewritePattern;
119
120 LogicalResult matchAndRewrite(vector::GatherOp op,
121 PatternRewriter &rewriter) const override {
122 Value base = op.getBase();
123
124 // TODO: Strided accesses might be coming from other ops as well
125 auto subview = base.getDefiningOp<memref::SubViewOp>();
126 if (!subview)
127 return failure();
128
129 auto sourceType = subview.getSource().getType();
130
131 // TODO: Allow ranks > 2.
132 if (sourceType.getRank() != 2)
133 return failure();
134
135 // Get strides
136 auto layout = subview.getResult().getType().getLayout();
137 auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(Val&: layout);
138 if (!stridedLayoutAttr)
139 return failure();
140
141 // TODO: Allow the access to be strided in multiple dimensions.
142 if (stridedLayoutAttr.getStrides().size() != 1)
143 return failure();
144
145 int64_t srcTrailingDim = sourceType.getShape().back();
146
147 // Assume that the stride matches the trailing dimension of the source
148 // memref.
149 // TODO: Relax this assumption.
150 if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim)
151 return failure();
152
153 // 1. Collapse the input memref so that it's "flat".
154 SmallVector<ReassociationIndices> reassoc = {{0, 1}};
155 Value collapsed = rewriter.create<memref::CollapseShapeOp>(
156 location: op.getLoc(), args: subview.getSource(), args&: reassoc);
157
158 // 2. Generate new gather indices that will model the
159 // strided access.
160 IntegerAttr stride = rewriter.getIndexAttr(value: srcTrailingDim);
161 VectorType vType = op.getIndexVec().getType();
162 Value mulCst = rewriter.create<arith::ConstantOp>(
163 location: op.getLoc(), args&: vType, args: DenseElementsAttr::get(type: vType, values: stride));
164
165 Value newIdxs =
166 rewriter.create<arith::MulIOp>(location: op.getLoc(), args: op.getIndexVec(), args&: mulCst);
167
168 // 3. Create an updated gather op with the collapsed input memref and the
169 // updated indices.
170 Value newGather = rewriter.create<vector::GatherOp>(
171 location: op.getLoc(), args: op.getResult().getType(), args&: collapsed, args: op.getIndices(),
172 args&: newIdxs, args: op.getMask(), args: op.getPassThru());
173 rewriter.replaceOp(op, newValues: newGather);
174
175 return success();
176 }
177};
178
/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
/// loads/extracts are made conditional using `scf.if` ops.
182struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
183 using OpRewritePattern::OpRewritePattern;
184
185 LogicalResult matchAndRewrite(vector::GatherOp op,
186 PatternRewriter &rewriter) const override {
187 VectorType resultTy = op.getType();
188 if (resultTy.getRank() != 1)
189 return rewriter.notifyMatchFailure(arg&: op, msg: "unsupported rank");
190
191 if (resultTy.isScalable())
192 return rewriter.notifyMatchFailure(arg&: op, msg: "not a fixed-width vector");
193
194 Location loc = op.getLoc();
195 Type elemTy = resultTy.getElementType();
196 // Vector type with a single element. Used to generate `vector.loads`.
197 VectorType elemVecTy = VectorType::get(shape: {1}, elementType: elemTy);
198
199 Value condMask = op.getMask();
200 Value base = op.getBase();
201
202 // vector.load requires the most minor memref dim to have unit stride
203 // (unless reading exactly 1 element)
204 if (auto memType = dyn_cast<MemRefType>(Val: base.getType())) {
205 if (auto stridesAttr =
206 dyn_cast_if_present<StridedLayoutAttr>(Val: memType.getLayout())) {
207 if (stridesAttr.getStrides().back() != 1 &&
208 resultTy.getNumElements() != 1)
209 return failure();
210 }
211 }
212
213 Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
214 location: loc, args: op.getIndexVectorType().clone(elementType: rewriter.getIndexType()),
215 args: op.getIndexVec());
216 auto baseOffsets = llvm::to_vector(Range: op.getIndices());
217 Value lastBaseOffset = baseOffsets.back();
218
219 Value result = op.getPassThru();
220
221 // Emit a conditional access for each vector element.
222 for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) {
223 int64_t thisIdx[1] = {i};
224 Value condition =
225 rewriter.create<vector::ExtractOp>(location: loc, args&: condMask, args&: thisIdx);
226 Value index = rewriter.create<vector::ExtractOp>(location: loc, args&: indexVec, args&: thisIdx);
227 baseOffsets.back() =
228 rewriter.createOrFold<arith::AddIOp>(location: loc, args&: lastBaseOffset, args&: index);
229
230 auto loadBuilder = [&](OpBuilder &b, Location loc) {
231 Value extracted;
232 if (isa<MemRefType>(Val: base.getType())) {
233 // `vector.load` does not support scalar result; emit a vector load
234 // and extract the single result instead.
235 Value load =
236 b.create<vector::LoadOp>(location: loc, args&: elemVecTy, args&: base, args&: baseOffsets);
237 int64_t zeroIdx[1] = {0};
238 extracted = b.create<vector::ExtractOp>(location: loc, args&: load, args&: zeroIdx);
239 } else {
240 extracted = b.create<tensor::ExtractOp>(location: loc, args&: base, args&: baseOffsets);
241 }
242
243 Value newResult =
244 b.create<vector::InsertOp>(location: loc, args&: extracted, args&: result, args&: thisIdx);
245 b.create<scf::YieldOp>(location: loc, args&: newResult);
246 };
247 auto passThruBuilder = [result](OpBuilder &b, Location loc) {
248 b.create<scf::YieldOp>(location: loc, args: result);
249 };
250
251 result =
252 rewriter
253 .create<scf::IfOp>(location: loc, args&: condition, /*thenBuilder=*/args&: loadBuilder,
254 /*elseBuilder=*/args&: passThruBuilder)
255 .getResult(i: 0);
256 }
257
258 rewriter.replaceOp(op, newValues: result);
259 return success();
260 }
261};
262} // namespace
263
264void mlir::vector::populateVectorGatherLoweringPatterns(
265 RewritePatternSet &patterns, PatternBenefit benefit) {
266 patterns.add<UnrollGather>(arg: patterns.getContext(), args&: benefit);
267}
268
269void mlir::vector::populateVectorGatherToConditionalLoadPatterns(
270 RewritePatternSet &patterns, PatternBenefit benefit) {
271 patterns.add<RemoveStrideFromGatherSource, Gather1DToConditionalLoads>(
272 arg: patterns.getContext(), args&: benefit);
273}
274

// Source: mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp