//===- VectorUtils.cpp - MLIR Utilities for VectorOps ------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utility methods for working with the Vector dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Utils/VectorUtils.h"

#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/InterleavedRange.h"

#define DEBUG_TYPE "vector-utils"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using namespace mlir;

/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                      int64_t dim) {
  if (isa<UnrankedMemRefType, MemRefType>(source.getType()))
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (isa<UnrankedTensorType, RankedTensorType>(source.getType()))
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

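// Illustrative usage (a sketch, assuming a builder `b`, a location `loc` and a
// value `src` of type tensor<4x?xf32>): the helper folds the static dimension
// to a constant and materializes a tensor.dim for the dynamic one.
//
//   Value d0 = vector::createOrFoldDimOp(b, loc, src, /*dim=*/0); // constant 4
//   Value d1 = vector::createOrFoldDimOp(b, loc, src, /*dim=*/1); // tensor.dim
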
/// Given the n-D transpose pattern 'transp', return true if 'dim0' and 'dim1'
/// should be transposed with each other within the context of their 2D
/// transposition slice.
///
/// Example 1: dim0 = 0, dim1 = 2, transp = [2, 1, 0]
/// Return true: dim0 and dim1 are transposed within the context of their 2D
/// transposition slice ([1, 0]).
///
/// Example 2: dim0 = 0, dim1 = 1, transp = [2, 1, 0]
/// Return true: dim0 and dim1 are transposed within the context of their 2D
/// transposition slice ([1, 0]). Paradoxically, note how dim1 (1) is *not*
/// transposed within the full context of the transposition.
///
/// Example 3: dim0 = 0, dim1 = 1, transp = [2, 0, 1]
/// Return false: dim0 and dim1 are *not* transposed within the context of
/// their 2D transposition slice ([0, 1]). Paradoxically, note how dim0 (0)
/// and dim1 (1) are transposed within the full context of the transposition.
static bool areDimsTransposedIn2DSlice(int64_t dim0, int64_t dim1,
                                       ArrayRef<int64_t> transp) {
  // Perform a linear scan along the dimensions of the transposed pattern. If
  // dim0 is found first, dim0 and dim1 are not transposed within the context
  // of their 2D slice. Otherwise, 'dim1' is found first and they are
  // transposed.
  for (int64_t permDim : transp) {
    if (permDim == dim0)
      return false;
    if (permDim == dim1)
      return true;
  }

  llvm_unreachable("Ill-formed transpose pattern");
}

FailureOr<std::pair<int, int>>
mlir::vector::isTranspose2DSlice(vector::TransposeOp op) {
  VectorType srcType = op.getSourceVectorType();
  SmallVector<int64_t> srcGtOneDims;
  for (auto [index, size] : llvm::enumerate(srcType.getShape()))
    if (size > 1)
      srcGtOneDims.push_back(index);

  if (srcGtOneDims.size() != 2)
    return failure();

  // Check whether the two source vector dimensions that are greater than one
  // must be transposed with each other so that we can apply one of the 2-D
  // transpose patterns. Otherwise, these patterns are not applicable.
  if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1],
                                  op.getPermutation()))
    return failure();

  return std::pair<int, int>(srcGtOneDims[0], srcGtOneDims[1]);
}

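// Worked example (a sketch of the semantics documented above): for
//
//   %t = vector.transpose %v, [0, 2, 1]
//       : vector<1x4x8xf32> to vector<1x8x4xf32>
//
// the two source dims greater than one are 1 and 2 and the permutation swaps
// them within their 2-D slice, so isTranspose2DSlice returns the pair (1, 2).
// For the permutation [1, 0, 2] on the same source type, the >1 dims are not
// transposed with each other and the function returns failure().
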
/// Constructs a permutation map from memref indices to vector dimension.
///
/// The implementation uses the knowledge of the mapping of enclosing loops to
/// vector dimensions. `enclosingLoopToVectorDim` carries this information as a
/// map with:
///   - keys representing "vectorized enclosing loops";
///   - values representing the corresponding vector dimension.
/// The algorithm traverses the "vectorized enclosing loops" and, for each one,
/// extracts the at-most-one MemRef index that varies along that loop (i.e.
/// that is *not* invariant with respect to its induction variable). Such an
/// index is guaranteed to be at most one by construction: otherwise the MemRef
/// is not vectorizable.
/// If such a varying index is found, it is added to the permutation_map at the
/// proper vector dimension.
/// If no index varies along the loop, 0 is added to the permutation_map, which
/// corresponds to a vector broadcast along that dimension.
///
/// Returns an empty AffineMap if `enclosingLoopToVectorDim` is empty,
/// signalling that no permutation map can be constructed given
/// `enclosingLoopToVectorDim`.
///
/// Examples can be found in the documentation of `makePermutationMap`, in the
/// header file.
static AffineMap makePermutationMap(
    ArrayRef<Value> indices,
    const DenseMap<Operation *, unsigned> &enclosingLoopToVectorDim) {
  if (enclosingLoopToVectorDim.empty())
    return AffineMap();
  MLIRContext *context =
      enclosingLoopToVectorDim.begin()->getFirst()->getContext();
  SmallVector<AffineExpr> perm(enclosingLoopToVectorDim.size(),
                               getAffineConstantExpr(0, context));

  for (auto kvp : enclosingLoopToVectorDim) {
    assert(kvp.second < perm.size());
    auto invariants = affine::getInvariantAccesses(
        cast<affine::AffineForOp>(kvp.first).getInductionVar(), indices);
    unsigned numIndices = indices.size();
    unsigned countInvariantIndices = 0;
    for (unsigned dim = 0; dim < numIndices; ++dim) {
      if (!invariants.count(indices[dim])) {
        assert(perm[kvp.second] == getAffineConstantExpr(0, context) &&
               "permutationMap already has an entry along dim");
        perm[kvp.second] = getAffineDimExpr(dim, context);
      } else {
        ++countInvariantIndices;
      }
    }
    assert((countInvariantIndices == numIndices ||
            countInvariantIndices == numIndices - 1) &&
           "Vectorization prerequisite violated: at most 1 index may vary "
           "wrt a vectorized loop");
    (void)countInvariantIndices;
  }
  return AffineMap::get(indices.size(), 0, perm, context);
}

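// Worked example (a sketch): assume a 2-D vectorization where loop %i maps to
// vector dimension 0, loop %j maps to vector dimension 1, and the memref
// indices passed in are (%i, %j), i.e. d0 and d1 below.
//   - Access A[%i, %j]: %i varies with loop %i and %j with loop %j, yielding
//     the identity map (d0, d1) -> (d0, d1).
//   - Access A[%c0, %j]: no index varies with loop %i, so vector dimension 0
//     becomes a broadcast, yielding (d0, d1) -> (0, d1).
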
/// Implementation detail that walks up the parents and records the ones with
/// the specified type.
/// TODO: could also be implemented as a collect parents followed by a
/// filter and made available outside this file.
template <typename T>
static SetVector<Operation *> getParentsOfType(Block *block) {
  SetVector<Operation *> res;
  auto *current = block->getParentOp();
  while (current) {
    if ([[maybe_unused]] auto typedParent = dyn_cast<T>(current)) {
      assert(res.count(current) == 0 && "Already inserted");
      res.insert(current);
    }
    current = current->getParentOp();
  }
  return res;
}

/// Returns the enclosing AffineForOps, from closest to farthest.
static SetVector<Operation *> getEnclosingforOps(Block *block) {
  return getParentsOfType<affine::AffineForOp>(block);
}

AffineMap mlir::makePermutationMap(
    Block *insertPoint, ArrayRef<Value> indices,
    const DenseMap<Operation *, unsigned> &loopToVectorDim) {
  DenseMap<Operation *, unsigned> enclosingLoopToVectorDim;
  auto enclosingLoops = getEnclosingforOps(insertPoint);
  for (auto *forInst : enclosingLoops) {
    auto it = loopToVectorDim.find(forInst);
    if (it != loopToVectorDim.end()) {
      enclosingLoopToVectorDim.insert(*it);
    }
  }
  return ::makePermutationMap(indices, enclosingLoopToVectorDim);
}

AffineMap mlir::makePermutationMap(
    Operation *op, ArrayRef<Value> indices,
    const DenseMap<Operation *, unsigned> &loopToVectorDim) {
  return makePermutationMap(op->getBlock(), indices, loopToVectorDim);
}

bool matcher::operatesOnSuperVectorsOf(Operation &op,
                                       VectorType subVectorType) {
  // First, extract the vector type and distinguish between:
  //   a. ops that *must* lower a super-vector (i.e. vector.transfer_read,
  //      vector.transfer_write); and
  //   b. ops that *may* lower a super-vector (all other ops).
  // The ops that *may* lower a super-vector only do so if the super-vector to
  // sub-vector ratio exists. The ops that *must* lower a super-vector are
  // explicitly checked for this property.
  /// TODO: there should be a single function for all ops to do this so we
  /// do not have to special case. Maybe a trait, or just a method, unclear atm.
  bool mustDivide = false;
  (void)mustDivide;
  VectorType superVectorType;
  if (auto transfer = dyn_cast<VectorTransferOpInterface>(op)) {
    superVectorType = transfer.getVectorType();
    mustDivide = true;
  } else if (op.getNumResults() == 0) {
    if (!isa<func::ReturnOp>(op)) {
      op.emitError("NYI: assuming only return operations can have 0 "
                   "results at this point");
    }
    return false;
  } else if (op.getNumResults() == 1) {
    if (auto v = dyn_cast<VectorType>(op.getResult(0).getType())) {
      superVectorType = v;
    } else {
      // Not a vector type.
      return false;
    }
  } else {
    // Not a vector.transfer and has more than 1 result, fail hard for now to
    // wake us up when something changes.
    op.emitError("NYI: operation has more than 1 result");
    return false;
  }

  // Get the ratio.
  auto ratio =
      computeShapeRatio(superVectorType.getShape(), subVectorType.getShape());

  // Sanity check.
  assert((ratio || !mustDivide) &&
         "vector.transfer operation in which super-vector size is not an"
         " integer multiple of sub-vector size");

  // This catches cases that are not strictly necessary to have multiplicity
  // but still aren't divisible by the sub-vector shape.
  // This could be useful information if we wanted to reshape at the level of
  // the vector type (but we would have to look at the compute and distinguish
  // between parallel, reduction and possibly other cases).
  return ratio.has_value();
}

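// As a rough illustration of the ratio check above (assuming computeShapeRatio
// divides the trailing dimensions and returns std::nullopt when the division
// is not integral):
//
//   computeShapeRatio({4, 8, 16}, {8, 16}) -> [4, 1, 1]
//   computeShapeRatio({4, 8, 16}, {4, 8})  -> [4, 2, 2]
//   computeShapeRatio({4, 8, 16}, {3})     -> std::nullopt
//
// so an op producing vector<4x8x16xf32> operates on super-vectors of
// vector<8x16xf32>, but not of vector<3xf32>.
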
bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
  if (vectorType.isScalable())
    return false;

  ArrayRef<int64_t> vectorShape = vectorType.getShape();
  auto vecRank = vectorType.getRank();

  if (!memrefType.areTrailingDimsContiguous(vecRank))
    return false;

  // Extract the trailing dims of the input memref.
  auto memrefShape = memrefType.getShape().take_back(vecRank);

  // Compare the dims of `vectorType` against `memrefType` (in reverse).
  // In the most basic case, all dims will match.
  auto firstNonMatchingDim =
      std::mismatch(vectorShape.rbegin(), vectorShape.rend(),
                    memrefShape.rbegin(), memrefShape.rend());
  if (firstNonMatchingDim.first == vectorShape.rend())
    return true;

  // One non-matching dim is still fine, however the remaining leading dims of
  // `vectorType` need to be 1.
  SmallVector<int64_t> leadingDims(++firstNonMatchingDim.first,
                                   vectorShape.rend());

  return llvm::all_of(leadingDims, [](auto x) { return x == 1; });
}

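// Worked example (a sketch): for a statically shaped, identity-layout
// memref<2x3x4xf32>:
//   - vector<4xf32>, vector<3x4xf32> and vector<1x1x4xf32> are contiguous
//     slices (the trailing dims match and any extra leading vector dims are
//     1);
//   - vector<2x3xf32> is not: its leading dim 2 spans two rows of length 3
//     that are not adjacent in memory (the innermost memref dim is 4).
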
std::optional<StaticTileOffsetRange>
vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
  if (vType.getRank() <= targetRank)
    return {};
  // Attempt to unroll until targetRank or the first scalable dimension (which
  // cannot be unrolled).
  auto shapeToUnroll = vType.getShape().drop_back(targetRank);
  auto scalableDimsToUnroll = vType.getScalableDims().drop_back(targetRank);
  auto it = llvm::find(scalableDimsToUnroll, true);
  auto firstScalableDim = it - scalableDimsToUnroll.begin();
  if (firstScalableDim == 0)
    return {};
  // All scalable dimensions should be removed now.
  scalableDimsToUnroll = scalableDimsToUnroll.slice(0, firstScalableDim);
  assert(!llvm::is_contained(scalableDimsToUnroll, true) &&
         "unexpected leading scalable dimension");
  // Create an unroll iterator for leading dimensions.
  shapeToUnroll = shapeToUnroll.slice(0, firstScalableDim);
  return StaticTileOffsetRange(shapeToUnroll, /*unrollStep=*/1);
}

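// Worked example (a sketch): for vType = vector<2x3x[4]x5xf32> and
// targetRank = 1, the trailing dim is dropped, the scalable dim [4] stops the
// unrolling, and the returned range iterates over the leading 2x3 tile
// offsets: [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]. If the leading
// dimension itself were scalable (e.g. vector<[2]x3xf32> with targetRank = 1),
// no iterator is returned.
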
SmallVector<OpFoldResult> vector::getMixedSizesXfer(bool hasTensorSemantics,
                                                    Operation *xfer,
                                                    RewriterBase &rewriter) {
  auto loc = xfer->getLoc();

  Value base = TypeSwitch<Operation *, Value>(xfer)
                   .Case<vector::TransferReadOp>(
                       [&](auto readOp) { return readOp.getBase(); })
                   .Case<vector::TransferWriteOp>(
                       [&](auto writeOp) { return writeOp.getOperand(1); });

  SmallVector<OpFoldResult> mixedSourceDims =
      hasTensorSemantics ? tensor::getMixedSizes(rewriter, loc, base)
                         : memref::getMixedSizes(rewriter, loc, base);
  return mixedSourceDims;
}

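// For example (a sketch), for a transfer op whose base is memref<4x?xf32>,
// the helper above returns the mixed sizes [4, %d1]: 4 as a static attribute
// and %d1 as the SSA value produced by memref.dim (or tensor.dim under tensor
// semantics) for the dynamic dimension.
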
bool vector::isLinearizableVector(VectorType type) {
  return (type.getRank() > 1) && (type.getNumScalableDims() <= 1);
}

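// Applying the predicate above: vector<2x4xf32> and vector<2x[4]xf32> are
// linearizable (rank > 1, at most one scalable dim), while vector<4xf32>
// (rank 1) and vector<[4]x[4]xf32> (two scalable dims) are not.
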
Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
                                     Value source,
                                     ArrayRef<int64_t> inputVectorSizes,
                                     Value padValue,
                                     bool useInBoundsInsteadOfMasking) {
  assert(llvm::none_of(inputVectorSizes,
                       [](int64_t s) { return s == ShapedType::kDynamic; }) &&
         "invalid input vector sizes");
  auto sourceShapedType = cast<ShapedType>(source.getType());
  auto sourceShape = sourceShapedType.getShape();
  assert(sourceShape.size() == inputVectorSizes.size() &&
         "expected same ranks.");
  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
  assert(padValue.getType() == sourceShapedType.getElementType() &&
         "expected same pad element type to match source element type");
  int64_t readRank = inputVectorSizes.size();
  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  SmallVector<bool> inBoundsVal(readRank, true);

  if (useInBoundsInsteadOfMasking) {
    // Update the inBounds attribute.
    // FIXME: This computation is too weak - it ignores the read indices.
    for (unsigned i = 0; i < readRank; i++)
      inBoundsVal[i] = (sourceShape[i] == inputVectorSizes[i]) &&
                       !ShapedType::isDynamic(sourceShape[i]);
  }
  auto transferReadOp = builder.create<vector::TransferReadOp>(
      loc,
      /*vectorType=*/vectorType,
      /*source=*/source,
      /*indices=*/SmallVector<Value>(readRank, zero),
      /*padding=*/padValue,
      /*inBounds=*/inBoundsVal);

  if (llvm::equal(inputVectorSizes, sourceShape) || useInBoundsInsteadOfMasking)
    return transferReadOp;
  SmallVector<OpFoldResult> mixedSourceDims =
      tensor::getMixedSizes(builder, loc, source);

  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
  Value mask =
      builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
  return mlir::vector::maskOperation(builder, transferReadOp, mask)
      ->getResult(0);
}

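// Sketch of the IR this helper produces (names and shapes are illustrative):
// for %src : tensor<?x8xf32>, inputVectorSizes = [4, 8], padding value %pad
// and useInBoundsInsteadOfMasking = false, the result is roughly
//
//   %c0   = arith.constant 0 : index
//   %d0   = tensor.dim %src, %c0 : tensor<?x8xf32>
//   %c8   = arith.constant 8 : index
//   %mask = vector.create_mask %d0, %c8 : vector<4x8xi1>
//   %read = vector.mask %mask {
//             vector.transfer_read %src[%c0, %c0], %pad
//                 : tensor<?x8xf32>, vector<4x8xf32>
//           } : vector<4x8xi1> -> vector<4x8xf32>
//
// i.e. a transfer_read from index (0, 0), padded with %pad and masked by the
// runtime sizes of %src.
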
LogicalResult
vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
                                 ArrayRef<int64_t> inputVectorSizes) {
  LDBG("Iteration space static sizes:" << llvm::interleaved(shape));

  if (inputVectorSizes.size() != shape.size()) {
    LDBG("Input vector sizes don't match the number of loops");
    return failure();
  }
  if (ShapedType::isDynamicShape(inputVectorSizes)) {
    LDBG("Input vector sizes can't have dynamic dimensions");
    return failure();
  }
  if (!llvm::all_of(llvm::zip(shape, inputVectorSizes),
                    [](std::tuple<int64_t, int64_t> sizePair) {
                      int64_t staticSize = std::get<0>(sizePair);
                      int64_t inputSize = std::get<1>(sizePair);
                      return ShapedType::isDynamic(staticSize) ||
                             staticSize <= inputSize;
                    })) {
    LDBG("Input vector sizes must be greater than or equal to iteration space "
         "static sizes");
    return failure();
  }
  return success();
}
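
// Worked example (a sketch): for an iteration space with static sizes [?, 8]
// (dynamic first dimension):
//   - inputVectorSizes = [4, 8] -> success (dynamic dims are unconstrained and
//     8 <= 8);
//   - inputVectorSizes = [4, 4] -> failure (8 > 4);
//   - inputVectorSizes = [4]    -> failure (rank mismatch).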