//===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Utils/Utils.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>

#define DEBUG_TYPE "linalg-utils"

using namespace mlir;
using namespace presburger;
using namespace mlir::affine;
using namespace mlir::linalg;
using namespace mlir::scf;

namespace {

// Helper visitor to determine whether an AffineExpr is tiled.
// This is achieved by traversing every AffineDimExpr with position `pos` and
// checking whether the corresponding `tileSizes[pos]` is non-zero.
// This also enforces that only positive coefficients occur in multiplications.
//
// Example:
//   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
//
struct TileCheck : public AffineExprVisitor<TileCheck> {
  TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}

  void visitDimExpr(AffineDimExpr expr) {
    isTiled |= !isZeroIndex(tileSizes[expr.getPosition()]);
  }
  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
    visit(expr.getLHS());
    visit(expr.getRHS());
    if (expr.getKind() == mlir::AffineExprKind::Mul)
      assert(cast<AffineConstantExpr>(expr.getRHS()).getValue() > 0 &&
             "nonpositive multiplying coefficient");
  }
  bool isTiled = false;
  ArrayRef<OpFoldResult> tileSizes;
};

} // namespace

static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes) {
  if (!expr)
    return false;
  TileCheck t(tileSizes);
  t.visit(expr);
  return t.isTiled;
}

// Checks whether the `map` varies with respect to a non-zero `tileSize`.
static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes) {
  if (!map)
    return false;
  for (unsigned r = 0; r < map.getNumResults(); ++r)
    if (isTiled(map.getResult(r), tileSizes))
      return true;
  return false;
}

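// The matcher below recognizes a `linalg.generic` region of the following
// shape (an illustrative sketch; any signless int or float block argument
// type is accepted, but only `arith.addi` is currently matched):
//   ^bb0(%a: i32, %b: i32):
//     %0 = arith.addi %a, %b : i32
//     linalg.yield %0 : i32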
std::optional<RegionMatcher::BinaryOpKind>
RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
  auto &region = op.getRegion();
  if (!llvm::hasSingleElement(region))
    return std::nullopt;

  Block &block = region.front();
  if (block.getNumArguments() != 2 ||
      !block.getArgument(0).getType().isSignlessIntOrFloat() ||
      !block.getArgument(1).getType().isSignlessIntOrFloat())
    return std::nullopt;

  auto &ops = block.getOperations();
  if (!llvm::hasSingleElement(block.without_terminator()))
    return std::nullopt;

  using mlir::matchers::m_Val;
  auto a = m_Val(block.getArgument(0));
  auto b = m_Val(block.getArgument(1));

  auto addPattern = m_Op<linalg::YieldOp>(m_Op<arith::AddIOp>(a, b));
  if (addPattern.match(&ops.back()))
    return BinaryOpKind::IAdd;

  return std::nullopt;
}

/// Explicit instantiation of loop nest generator for different loop types.
template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
template struct mlir::linalg::GenerateLoopNest<AffineForOp>;

/// Given a list of subview ranges, extract individual values for lower, upper
/// bounds and steps and put them into the corresponding vectors.
static void unpackRanges(OpBuilder &builder, Location loc,
                         ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
                         SmallVectorImpl<Value> &ubs,
                         SmallVectorImpl<Value> &steps) {
  for (Range range : ranges) {
    lbs.emplace_back(
        getValueOrCreateConstantIndexOp(builder, loc, range.offset));
    ubs.emplace_back(getValueOrCreateConstantIndexOp(builder, loc, range.size));
    steps.emplace_back(
        getValueOrCreateConstantIndexOp(builder, loc, range.stride));
  }
}

//===----------------------------------------------------------------------===//
// General utilities
//===----------------------------------------------------------------------===//

namespace mlir {
namespace linalg {

bool allIndexingsAreProjectedPermutation(LinalgOp op) {
  return llvm::all_of(op.getIndexingMapsArray(), [](AffineMap m) {
    return m.isProjectedPermutation(/*allowZeroInResults=*/true);
  });
}

bool hasOnlyScalarElementwiseOp(Region &r) {
  if (!llvm::hasSingleElement(r))
    return false;
  for (Operation &op : r.front()) {
    if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
              linalg::YieldOp, linalg::IndexOp, AffineApplyOp>(op) ||
          OpTrait::hasElementwiseMappableTraits(&op)) ||
        llvm::any_of(op.getResultTypes(),
                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
      return false;
  }
  return true;
}

bool isElementwise(LinalgOp op) {
  if (op.getNumLoops() != op.getNumParallelLoops())
    return false;

  if (!allIndexingsAreProjectedPermutation(op))
    return false;

  // TODO: relax the restrictions on indexing map.
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    if (!op.getMatchingIndexingMap(&opOperand).isPermutation())
      return false;
  }
  return hasOnlyScalarElementwiseOp(op->getRegion(0));
}

bool isParallelIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::parallel;
}

bool isReductionIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::reduction;
}

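// The matching in `makeComposedPadHighOp` below recognizes, roughly, IR of the
// following shape (an illustrative sketch, not the only accepted pattern):
//   %padded = tensor.pad %slice low[0, ...] high[...] { ... }
//   %filled = linalg.fill ins(%cst) outs(%padded)
//   %source = tensor.extract_slice %filled [...]
// When the pad value and the padded sizes match, padding `%source` again is
// folded away by returning the already padded value `%filled` directly.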
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                            Value source, Value pad, bool nofold) {
  // Exit if `source` is not defined by an ExtractSliceOp.
  auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Search the `source` use-def chain for padded LinalgOps.
  Value current = sliceOp.getSource();
  while (current) {
    auto linalgOp = current.getDefiningOp<LinalgOp>();
    if (!linalgOp)
      break;
    OpResult opResult = cast<OpResult>(current);
    current = linalgOp.getDpsInitOperand(opResult.getResultNumber())->get();
  }
  auto padOp = current ? current.getDefiningOp<tensor::PadOp>() : nullptr;

  // Exit if the search fails to match a tensor::PadOp at the end of the
  // matched LinalgOp sequence.
  if (!padOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padded result type does not match.
  if (sliceOp.getSource().getType() != type)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the LinalgOps are not high padded.
  if (llvm::any_of(padOp.getMixedLowPad(), [](OpFoldResult ofr) {
        return getConstantIntValue(ofr) != static_cast<int64_t>(0);
      }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if `padOpSliceOp`, which defines the slice used by `padOp`, is
  // rank-reducing.
  auto padOpSliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (!padOpSliceOp ||
      sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the sizes of `sliceOp` do not match the sizes of the slice padded
  // by `padOp`.
  if (llvm::any_of(
          llvm::zip(sliceOp.getMixedSizes(), padOpSliceOp.getMixedSizes()),
          [](std::tuple<OpFoldResult, OpFoldResult> it) {
            return !isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it));
          }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padding values do not match.
  Attribute padOpPadAttr, padAttr;
  Value padOpPad = padOp.getConstantPaddingValue();
  if (!padOpPad || !matchPattern(padOpPad, m_Constant(&padOpPadAttr)) ||
      !matchPattern(pad, m_Constant(&padAttr)) || padOpPadAttr != padAttr)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Return the padded result if the padding values and sizes match.
  return sliceOp.getSource();
}

GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
                          Value outputTensor,
                          ArrayRef<int64_t> transposeVector) {
  auto resultTensorType = cast<RankedTensorType>(outputTensor.getType());
  Type elementType = resultTensorType.getElementType();

  assert(isPermutationVector(transposeVector) &&
         "expect transpose vector to be a permutation");
  assert(transposeVector.size() ==
             static_cast<size_t>(resultTensorType.getRank()) &&
         "expect transpose vector size to match result tensor rank");

  // Compute the transpose and the identity indexing maps.
  SmallVector<AffineMap> indexingMaps = {
      inversePermutation(AffineMap::getPermutationMap(
          SmallVector<unsigned>(transposeVector.begin(), transposeVector.end()),
          b.getContext())),
      AffineMap::getMultiDimIdentityMap(transposeVector.size(),
                                        b.getContext())};
  SmallVector<utils::IteratorType> iteratorTypes(transposeVector.size(),
                                                 utils::IteratorType::parallel);

  // Create a GenericOp to transpose `inputTensor` into `outputTensor`.
  auto transposeOp =
      b.create<GenericOp>(loc, resultTensorType, inputTensor, outputTensor,
                          indexingMaps, iteratorTypes);

  // Create the body of the transpose operation.
  OpBuilder::InsertionGuard g(b);
  Region &body = transposeOp.getRegion();
  Block *bodyBlock = b.createBlock(&body, /*insertPt=*/{},
                                   {elementType, elementType}, {loc, loc});
  b.create<YieldOp>(loc, bodyBlock->getArgument(0));
  return transposeOp;
}

GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
  auto memrefTypeTo = cast<MemRefType>(to.getType());
#ifndef NDEBUG
  auto memrefTypeFrom = cast<MemRefType>(from.getType());
  assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
         "`from` and `to` memref must have the same rank");
#endif // NDEBUG

  AffineMap id =
      AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
  SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(),
                                                 utils::IteratorType::parallel);
  return b.create<linalg::GenericOp>(
      loc,
      /*inputs=*/from,
      /*outputs=*/to,
      /*indexingMaps=*/llvm::ArrayRef({id, id}),
      /*iteratorTypes=*/iteratorTypes,
      [](OpBuilder &b, Location loc, ValueRange args) {
        b.create<linalg::YieldOp>(loc, args.front());
      });
}

/// Specialization to build an scf "for" nest.
template <>
void GenerateLoopNest<scf::ForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected as many entries for proc info as number of loops, even if "
         "they are null entries");
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
  LoopNest loopNest = mlir::scf::buildLoopNest(
      b, loc, lbs, ubs, steps, iterArgInitValues,
      [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
        assert(iterArgs.size() == iterArgInitValues.size() &&
               "expect the number of output tensors and iter args to match");
        SmallVector<Value> operandValuesToUse = linalgOp->getOperands();
        if (!iterArgs.empty()) {
          operandValuesToUse = linalgOp.getDpsInputs();
          operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
        }
        return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
      });

  if (loopNest.loops.empty() || procInfo.empty())
    return;

  // Filter out scf.for loops that were created out of parallel dimensions.
  for (const auto &loop : llvm::enumerate(loopNest.loops)) {
    if (procInfo[loop.index()].distributionMethod ==
        DistributionMethod::Cyclic) {
      mapLoopToProcessorIds(loop.value(), procInfo[loop.index()].procId,
                            procInfo[loop.index()].nprocs);
    }
  }
}

/// Specialization to build affine "for" nest.
template <>
void GenerateLoopNest<AffineForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> /*procInfo*/) {
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
  assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);

  // Affine loops require constant steps.
  SmallVector<int64_t, 4> constantSteps;
  constantSteps.reserve(steps.size());
  for (Value v : steps) {
    auto constVal = getConstantIntValue(v);
    assert(constVal.has_value() && "Affine loops require constant steps");
    constantSteps.push_back(constVal.value());
  }

  affine::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
                              [&](OpBuilder &b, Location loc, ValueRange ivs) {
                                bodyBuilderFn(b, loc, ivs,
                                              linalgOp->getOperands());
                              });
}

/// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
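/// For cyclic distribution this amounts to (a sketch of the affine maps built
/// below): `lb = lb + procId * step` and `step = nprocs * step`; `ub` is left
/// unchanged.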
void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
                                       Value nprocs, Value &lb, Value &ub,
                                       Value &step) {
  AffineExpr d0, d1;
  bindDims(b.getContext(), d0, d1);
  AffineExpr s0 = getAffineSymbolExpr(0, b.getContext());
  lb =
      affine::makeComposedAffineApply(b, loc, d0 + d1 * s0, {lb, procId, step});
  step = affine::makeComposedAffineApply(b, loc, d0 * s0, {nprocs, step});
}

/// Generates a loop nest consisting of scf.parallel and scf.for, depending
/// on the `iteratorTypes`. Consecutive parallel loops create a single
/// scf.parallel operation; each sequential loop creates a new scf.for
/// operation. The body of the innermost loop is populated by
/// `bodyBuilderFn` that accepts a range of induction variables for all
/// loops. `ivStorage` is used to store the partial list of induction
/// variables.
// TODO: this function can be made iterative instead. However, it
// will have at most as many recursive calls as nested loops, which rarely
// exceeds 10.
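// For example (illustrative), when no distribution info is provided, the
// iterator types
//   [parallel, parallel, reduction, parallel]
// produce a nest of the form
//   scf.parallel (%i, %j) { scf.for %k { scf.parallel (%l) { <body> } } }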
static void generateParallelLoopNest(
    OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
    ValueRange steps, ArrayRef<utils::IteratorType> iteratorTypes,
    ArrayRef<linalg::ProcInfo> procInfo,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
    SmallVectorImpl<Value> &ivStorage) {
  assert(lbs.size() == ubs.size());
  assert(lbs.size() == steps.size());
  assert(lbs.size() == iteratorTypes.size());
  assert(procInfo.empty() || (lbs.size() == procInfo.size()));

  // If there are no (more) loops to be generated, generate the body and be
  // done with it.
  if (iteratorTypes.empty()) {
    bodyBuilderFn(b, loc, ivStorage);
    return;
  }

  // If there are no outer parallel loops, generate one sequential loop and
  // recurse.
  if (!isParallelIterator(iteratorTypes.front())) {
    LoopNest singleLoop = buildLoopNest(
        b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(),
        [&](OpBuilder &b, Location loc, ValueRange ivs) {
          ivStorage.append(ivs.begin(), ivs.end());
          generateParallelLoopNest(
              b, loc, lbs.drop_front(), ubs.drop_front(), steps.drop_front(),
              iteratorTypes.drop_front(),
              procInfo.empty() ? procInfo : procInfo.drop_front(),
              bodyBuilderFn, ivStorage);
        });
    return;
  }

  unsigned nLoops = iteratorTypes.size();
  unsigned numProcessed = 0;
  DistributionMethod distributionMethod = DistributionMethod::None;
  if (procInfo.empty()) {
    numProcessed = nLoops - iteratorTypes.drop_while(isParallelIterator).size();
  } else {
    distributionMethod = procInfo.front().distributionMethod;
    numProcessed =
        nLoops - procInfo
                     .drop_while([&](linalg::ProcInfo p) {
                       return p.distributionMethod == distributionMethod;
                     })
                     .size();
  }

  auto remainderProcInfo =
      procInfo.empty() ? procInfo : procInfo.drop_front(numProcessed);
  switch (distributionMethod) {
  case DistributionMethod::None: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::Cyclic: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::CyclicNumProcsGeNumIters: {
    // Check (for the processed loops) that the iteration is in-bounds.
    ArithBuilder ab(b, loc);
    Value cond = ab.slt(lbs[0], ubs[0]);
    for (unsigned i = 1; i < numProcessed; ++i)
      cond = ab._and(cond, ab.slt(lbs[i], ubs[i]));
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    b.create<scf::IfOp>(loc, cond, [&](OpBuilder &b, Location loc) {
      generateParallelLoopNest(b, loc, lbs.drop_front(numProcessed),
                               ubs.drop_front(numProcessed),
                               steps.drop_front(numProcessed),
                               iteratorTypes.drop_front(numProcessed),
                               remainderProcInfo, bodyBuilderFn, ivStorage);
      b.create<scf::YieldOp>(loc, ValueRange{});
    });
    return;
  }
  case DistributionMethod::CyclicNumProcsEqNumIters:
    // No check/loops needed here. Set the `%iv` to be the `%lb` and proceed
    // with inner loop generation.
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    generateParallelLoopNest(
        b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
        steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
        remainderProcInfo, bodyBuilderFn, ivStorage);
    return;
  }
}

/// Specialization for generating a mix of parallel and sequential scf loops.
template <>
void GenerateLoopNest<scf::ParallelOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
  assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
  // This function may be passed more iterator types than ranges.
  assert(iteratorTypes.size() >= loopRanges.size() &&
         "expected iterator type for all ranges");
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected proc information for all loops when present");
  iteratorTypes = iteratorTypes.take_front(loopRanges.size());
  SmallVector<Value, 8> lbsStorage, ubsStorage, stepsStorage, ivs;
  unsigned numLoops = iteratorTypes.size();
  ivs.reserve(numLoops);
  lbsStorage.reserve(numLoops);
  ubsStorage.reserve(numLoops);
  stepsStorage.reserve(numLoops);

  // Get the loop lb, ub, and step.
  unpackRanges(b, loc, loopRanges, lbsStorage, ubsStorage, stepsStorage);

  // Modify the lb, ub, and step based on the distribution options.
  for (const auto &it : llvm::enumerate(procInfo)) {
    if (it.value().distributionMethod != linalg::DistributionMethod::None) {
      updateBoundsForCyclicDistribution(
          b, loc, it.value().procId, it.value().nprocs, lbsStorage[it.index()],
          ubsStorage[it.index()], stepsStorage[it.index()]);
    }
  }
  ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
  generateParallelLoopNest(
      b, loc, lbs, ubs, steps, iteratorTypes, procInfo,
      [&](OpBuilder &b, Location loc, ValueRange ivs) {
        bodyBuilderFn(b, loc, ivs, linalgOp->getOperands());
      },
      ivs);

  assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
}

static Value materializeTiledShape(OpBuilder &builder, Location loc,
                                   Value valueToTile,
                                   const SliceParameters &sliceParams) {
  auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
  auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
                      .Case([&](MemRefType) {
                        return builder.create<memref::SubViewOp>(
                            loc, valueToTile, sliceParams.offsets,
                            sliceParams.sizes, sliceParams.strides);
                      })
                      .Case([&](RankedTensorType) {
                        return builder.create<tensor::ExtractSliceOp>(
                            loc, valueToTile, sliceParams.offsets,
                            sliceParams.sizes, sliceParams.strides);
                      })
                      .Default([](ShapedType) -> Operation * {
                        llvm_unreachable("Unexpected shaped type");
                      });
  return sliceOp->getResult(0);
}

Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
                     ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                     ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
                     ArrayRef<OpFoldResult> subShapeSizes,
                     bool omitPartialTileCheck) {
  SliceParameters sliceParams =
      computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
                             ubs, subShapeSizes, omitPartialTileCheck);
  return materializeTiledShape(builder, loc, valueToTile, sliceParams);
}

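// For each tiled dimension `r`, the parameters computed below are (a sketch of
// the formulas used, ignoring the partial-tile clamping):
//   offsets[r] = map.getSubMap({r})(lbs)
//   sizes[r]   = map.getSubMap({r})(subShapeSizes) + 1  // back to half-open
//   strides[r] = 1
// Untiled dimensions use offset 0, the full dimension size, and stride 1.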
SliceParameters
computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
                       ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                       ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
                       ArrayRef<OpFoldResult> subShapeSizes,
                       bool omitPartialTileCheck) {
  auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
  assert(shapedType && "only shaped types can be tiled");
  ArrayRef<int64_t> shape = shapedType.getShape();
  int64_t rank = shapedType.getRank();

  // Compute offsets/sizes/strides for the tile.
  SliceParameters sliceParams;
  sliceParams.offsets.reserve(rank);
  sliceParams.sizes.reserve(rank);
  sliceParams.strides.reserve(rank);
  for (unsigned r = 0; r < rank; ++r) {
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r);
    if (!isTiled(map.getSubMap({r}), tileSizes)) {
      sliceParams.offsets.push_back(builder.getIndexAttr(0));
      OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
      sliceParams.sizes.push_back(dim);
      sliceParams.strides.push_back(builder.getIndexAttr(1));
      LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");

    // Tiling creates a new slice at the proper index, the slice step is 1
    // (i.e. the op does not subsample, stepping occurs in the loop).
    auto m = map.getSubMap({r});
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: submap: " << m << "\n");
    IRRewriter rewriter(builder);
    OpFoldResult offset = makeComposedFoldedAffineApply(rewriter, loc, m, lbs);
    sliceParams.offsets.push_back(offset);
    OpFoldResult closedIntSize =
        makeComposedFoldedAffineApply(rewriter, loc, m, subShapeSizes);
    // Resulting size needs to be made half open interval again.
    AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
    OpFoldResult size =
        makeComposedFoldedAffineApply(rewriter, loc, s0 + 1, closedIntSize);
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: raw size: " << size << "\n");
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: new offset: " << offset << "\n");
    sliceParams.strides.push_back(builder.getIndexAttr(1));

    if (omitPartialTileCheck) {
      // We statically know that the partial/boundary tile condition is
      // unnecessary.
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
      sliceParams.sizes.push_back(size);
      continue;
    }

    // The size of the subview / extract_slice should be trimmed to avoid
    // out-of-bounds accesses, unless:
    // a. We statically know the subshape size divides the shape size evenly.
    // b. The subshape size is 1. According to the way the loops are set up,
    //    tensors with "0" dimensions would never be constructed.
    int64_t shapeSize = shape[r];
    std::optional<int64_t> sizeCst = getConstantIntValue(size);
    auto hasTileSizeOne = sizeCst && *sizeCst == 1;
    auto dividesEvenly = sizeCst && !ShapedType::isDynamic(shapeSize) &&
                         ((shapeSize % *sizeCst) == 0);
    if (!hasTileSizeOne && !dividesEvenly) {
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
                              << ", size: " << size
                              << ": make sure in bound with affine.min\n");

      AffineExpr dim0, dim1, dim2;
      MLIRContext *context = builder.getContext();
      bindDims(context, dim0, dim1, dim2);

      // Get the dimension size for this dimension. We need to first calculate
      // the max index and then add one. This is important because for
      // convolution ops, the affine map of an input window dimension has the
      // form `(d0 * s0 + d1)`, where `d0`/`d1` is an output/filter window
      // dimension and `s0` is the stride. Directly using the dimension size of
      // the output/filter window dimensions would cause an incorrect
      // calculation.
      AffineMap minusOneMap = AffineMap::inferFromExprList(
                                  {ArrayRef<AffineExpr>{dim0 - 1}}, context)
                                  .front();
      AffineMap plusOneMap = AffineMap::inferFromExprList(
                                 {ArrayRef<AffineExpr>{dim0 + 1}}, context)
                                 .front();
      SmallVector<OpFoldResult> maxIndices =
          llvm::to_vector(llvm::map_range(ubs, [&](OpFoldResult ub) {
            return makeComposedFoldedAffineApply(rewriter, loc, minusOneMap,
                                                 {ub});
          }));
      OpFoldResult maxIndex =
          makeComposedFoldedAffineApply(rewriter, loc, m, maxIndices);
      OpFoldResult d =
          makeComposedFoldedAffineApply(rewriter, loc, plusOneMap, {maxIndex});

      // Compute min(dim - offset, size) to avoid out-of-bounds accesses.
      AffineMap minMap = AffineMap::inferFromExprList(
                             {ArrayRef<AffineExpr>{dim1 - dim2, dim0}}, context)
                             .front();
      size =
          makeComposedFoldedAffineMin(rewriter, loc, minMap, {size, d, offset});
    }
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
    sliceParams.sizes.push_back(size);
  }
  return sliceParams;
}

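// For example (illustrative), with tileSizes = [0, 8, 0, 4] and ivs = [%i, %j]
// the returned offsets are [0, %i, 0, %j]: untiled dimensions get a zero
// offset and tiled dimensions consume the next induction variable.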
SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
                                             ArrayRef<OpFoldResult> ivs,
                                             ArrayRef<OpFoldResult> tileSizes) {
  SmallVector<OpFoldResult> offsets;
  for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
    bool isTiled = !isZeroIndex(tileSizes[idx]);
    offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0));
    LLVM_DEBUG(llvm::dbgs()
               << "computeTileOffsets: " << offsets.back() << "\n");
  }
  return offsets;
}

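// For example (illustrative), with tileSizes = [0, 8] and sizeBounds =
// [%d0, %d1] the returned sizes are [%d0 - 1, 7]: the tile size for tiled
// dimensions and the full size bound otherwise, each reduced by one to form a
// closed interval.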
SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
                                           ArrayRef<OpFoldResult> tileSizes,
                                           ArrayRef<OpFoldResult> sizeBounds) {
  SmallVector<OpFoldResult> sizes;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
    bool isTiled = !isZeroIndex(tileSizes[idx]);
    // Before composing, we need to make the range a closed interval.
    OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx];
    AffineExpr d0 = getAffineDimExpr(0, b.getContext());
    IRRewriter rewriter(b);
    sizes.push_back(makeComposedFoldedAffineApply(rewriter, loc, d0 - 1, size));
    LLVM_DEBUG(llvm::dbgs() << "computeTileSizes: " << sizes.back() << "\n");
  }
  return sizes;
}

SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
  if (op.hasPureBufferSemantics())
    return {};
  return llvm::to_vector(
      llvm::map_range(op.getDpsInitsMutable(), [&](OpOperand &opOperand) {
        return operands[opOperand.getOperandNumber()].getType();
      }));
}

SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
                                    LinalgOp op, ValueRange operands,
                                    ValueRange results) {
  if (op.hasPureBufferSemantics())
    return {};
  SmallVector<Value> tensorResults;
  tensorResults.reserve(results.size());
  // Insert an insert_slice for each output tensor.
  unsigned resultIdx = 0;
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    // TODO: use an interface/adaptor to avoid leaking position in
    // `tiledOperands`.
    Value outputTensor = operands[opOperand.getOperandNumber()];
    if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
      Value inserted = builder.create<tensor::InsertSliceOp>(
          loc, sliceOp.getSource().getType(), results[resultIdx],
          sliceOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(),
          sliceOp.getStrides(), sliceOp.getStaticOffsets(),
          sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
      tensorResults.push_back(inserted);
    } else {
      tensorResults.push_back(results[resultIdx]);
    }
    ++resultIdx;
  }
  return tensorResults;
}

SmallVector<std::optional<SliceParameters>>
computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
                          ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
                          ArrayRef<OpFoldResult> tileSizes,
                          ArrayRef<OpFoldResult> sizeBounds,
                          bool omitPartialTileCheck) {
  assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                           llvm::make_range(tileSizes.begin(), tileSizes.end()),
                           [](OpFoldResult v) { return !isZeroIndex(v); })) &&
         "expected as many ivs as non-zero sizes");

  // Construct (potentially temporary) mins and maxes on which to apply maps
  // that define tile subshapes.
  SmallVector<OpFoldResult> lbs =
      computeTileOffsets(builder, loc, ivs, tileSizes);
  SmallVector<OpFoldResult> subShapeSizes =
      computeTileSizes(builder, loc, tileSizes, sizeBounds);

  assert(static_cast<int64_t>(valuesToTile.size()) <=
             linalgOp->getNumOperands() &&
         "more values to tile than operands");
  SmallVector<std::optional<SliceParameters>> allSliceParams;
  allSliceParams.reserve(valuesToTile.size());
  for (auto [opOperand, val] :
       llvm::zip(linalgOp->getOpOperands(), valuesToTile)) {
    Value shapedOp = val;
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
    AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
    // Use `opOperand` as is if it is not tiled and not an output tensor. Having
    // an extract/insert slice pair for all output tensors simplifies follow-up
    // transformations such as padding and bufferization since the
    // extract/insert slice pairs make the accessed iteration argument
    // subdomains explicit.

    Type operandType = opOperand.get().getType();
    if (!isTiled(map, tileSizes) && !(isa<RankedTensorType>(operandType) &&
                                      linalgOp.isDpsInit(&opOperand))) {
      allSliceParams.push_back(std::nullopt);
      LLVM_DEBUG(llvm::dbgs()
                 << ": not tiled: use shape: " << operandType << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");

    allSliceParams.push_back(computeSliceParameters(
        builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
        omitPartialTileCheck));
  }

  return allSliceParams;
}

SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
                                   LinalgOp linalgOp, ValueRange valuesToTile,
                                   ArrayRef<OpFoldResult> ivs,
                                   ArrayRef<OpFoldResult> tileSizes,
                                   ArrayRef<OpFoldResult> sizeBounds,
                                   bool omitPartialTileCheck) {
  SmallVector<std::optional<SliceParameters>> allSliceParameter =
      computeAllSliceParameters(builder, loc, linalgOp, valuesToTile, ivs,
                                tileSizes, sizeBounds, omitPartialTileCheck);
  SmallVector<Value> tiledShapes;
  for (auto item : llvm::zip(valuesToTile, allSliceParameter)) {
    Value valueToTile = std::get<0>(item);
    std::optional<SliceParameters> sliceParams = std::get<1>(item);
    tiledShapes.push_back(
        sliceParams.has_value()
            ? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
            : valueToTile);
  }
  return tiledShapes;
}

void offsetIndices(OpBuilder &b, LinalgOp linalgOp,
                   ArrayRef<OpFoldResult> offsets) {
  IRRewriter rewriter(b);
  offsetIndices(rewriter, linalgOp, offsets);
}

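/// For example (illustrative), given an offset %off for dimension 0, every use
/// of `%i = linalg.index 0` in the body is rewritten to use
/// `affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %off)` instead.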
void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
                   ArrayRef<OpFoldResult> offsets) {
  if (!linalgOp.hasIndexSemantics())
    return;

  for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) {
    if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()])
      continue;
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointAfter(indexOp);
    AffineExpr index, offset;
    bindDims(b.getContext(), index, offset);
    OpFoldResult applied = makeComposedFoldedAffineApply(
        b, indexOp.getLoc(), index + offset,
        {getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
    Value materialized =
        getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
    b.replaceUsesWithIf(indexOp, materialized, [&](OpOperand &use) {
      return use.getOwner() != materialized.getDefiningOp();
    });
  }
}

/// Get the reassociation maps to fold the result of an extract_slice (or the
/// source of an insert_slice) operation with the given offsets and sizes to
/// its rank-reduced version. This is only done for the cases where the size is
/// 1 and the offset is 0. Strictly speaking the offset 0 is not required in
/// general, but non-zero offsets are not handled by the SPIR-V backend at this
/// point (and potentially cannot be handled).
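/// For example (illustrative), mixedSizes = [1, 128, 1, 32] yields the
/// reassociation [[0, 1], [2, 3]]: each unit dimension is folded into the next
/// non-unit dimension, and a trailing run of unit dimensions is folded into
/// the last group.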
std::optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<ReassociationIndices> reassociation;
  ReassociationIndices curr;
  for (const auto &it : llvm::enumerate(mixedSizes)) {
    auto dim = it.index();
    auto size = it.value();
    curr.push_back(dim);
    auto attr = llvm::dyn_cast_if_present<Attribute>(size);
    if (attr && cast<IntegerAttr>(attr).getInt() == 1)
      continue;
    reassociation.emplace_back(ReassociationIndices{});
    std::swap(reassociation.back(), curr);
  }
  // When the reassociations are not empty, fold the remaining unit dimensions
  // into the last dimension. If the reassociations so far are empty, leave
  // them empty. This will fold everything to a rank-0 tensor.
  if (!curr.empty() && !reassociation.empty())
    reassociation.back().append(curr.begin(), curr.end());
  return reassociation;
}

} // namespace linalg
} // namespace mlir