//===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Utils/Utils.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>

#define DEBUG_TYPE "linalg-utils"

using namespace mlir;
using namespace presburger;
using namespace mlir::affine;
using namespace mlir::linalg;
using namespace mlir::scf;

namespace {

// Helper visitor to determine whether an AffineExpr is tiled.
// This is achieved by traversing every AffineDimExpr with position `pos` and
// checking whether the corresponding `tileSizes[pos]` is non-zero.
// This also enforces that only positive coefficients occur in multiplications.
//
// Example:
//   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
//
struct TileCheck : public AffineExprVisitor<TileCheck> {
  TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}

  void visitDimExpr(AffineDimExpr expr) {
    isTiled |= !isZeroInteger(tileSizes[expr.getPosition()]);
  }
  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
    visit(expr.getLHS());
    visit(expr.getRHS());
    if (expr.getKind() == mlir::AffineExprKind::Mul)
      assert(cast<AffineConstantExpr>(expr.getRHS()).getValue() > 0 &&
             "nonpositive multiplying coefficient");
  }
  bool isTiled = false;
  ArrayRef<OpFoldResult> tileSizes;
};

} // namespace

static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes) {
  if (!expr)
    return false;
  TileCheck t(tileSizes);
  t.visit(expr);
  return t.isTiled;
}

// Checks whether the `map` varies with respect to a non-zero `tileSize`.
static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes) {
  if (!map)
    return false;
  for (unsigned r = 0; r < map.getNumResults(); ++r)
    if (isTiled(map.getResult(r), tileSizes))
      return true;
  return false;
}
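// Example for `isTiled(AffineMap, ...)` above (illustrative, not part of the
// upstream source): for the map `(d0, d1, d2) -> (d0, d2)` and
// `tileSizes = [0, 4, 0]`, the result is false because the only tiled
// dimension (d1) does not appear in any result expression; with
// `tileSizes = [4, 0, 0]` it is true because d0 is tiled.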

std::optional<RegionMatcher::BinaryOpKind>
RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
  auto &region = op.getRegion();
  if (!llvm::hasSingleElement(region))
    return std::nullopt;

  Block &block = region.front();
  if (block.getNumArguments() != 2 ||
      !block.getArgument(0).getType().isSignlessIntOrFloat() ||
      !block.getArgument(1).getType().isSignlessIntOrFloat())
    return std::nullopt;

  auto &ops = block.getOperations();
  if (!llvm::hasSingleElement(block.without_terminator()))
    return std::nullopt;

  using mlir::matchers::m_Val;
  auto a = m_Val(block.getArgument(0));
  auto b = m_Val(block.getArgument(1));

  auto addPattern = m_Op<linalg::YieldOp>(m_Op<arith::AddIOp>(a, b));
  if (addPattern.match(&ops.back()))
    return BinaryOpKind::IAdd;

  return std::nullopt;
}
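// Example of a generic-op region matched by `matchAsScalarBinaryOp` above as
// BinaryOpKind::IAdd (illustrative sketch, not part of the upstream source):
//
//   ^bb0(%a: i32, %b: i32):
//     %sum = arith.addi %a, %b : i32
//     linalg.yield %sum : i32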

/// Explicit instantiation of loop nest generator for different loop types.
template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
template struct mlir::linalg::GenerateLoopNest<AffineForOp>;

/// Given a list of subview ranges, extract individual values for lower, upper
/// bounds and steps and put them into the corresponding vectors.
static void unpackRanges(OpBuilder &builder, Location loc,
                         ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
                         SmallVectorImpl<Value> &ubs,
                         SmallVectorImpl<Value> &steps) {
  for (Range range : ranges) {
    lbs.emplace_back(
        getValueOrCreateConstantIndexOp(builder, loc, range.offset));
    ubs.emplace_back(getValueOrCreateConstantIndexOp(builder, loc, range.size));
    steps.emplace_back(
        getValueOrCreateConstantIndexOp(builder, loc, range.stride));
  }
}

//===----------------------------------------------------------------------===//
// General utilities
//===----------------------------------------------------------------------===//
//
/// The permutation can be obtained from two permutations:
///   a) Compute the permutation vector to move the last `numPackedDims` into
///      the `innerPosDims` of a shape of rank `rank`.
///   b) Compute the permutation vector to move outer dims if the
///      `outerPerm` parameter is not empty.
/// Apply (b) permutation on (a) permutation to get the final permutation.
static SmallVector<int64_t>
computePackUnPackPerm(int64_t rank, ArrayRef<int64_t> &innerDimsPos,
                      ArrayRef<int64_t> &outerPerm,
                      PackingMetadata &packingMetadata) {
  int64_t numPackedDims = innerDimsPos.size();
  auto lastDims =
      llvm::to_vector(llvm::seq<int64_t>(rank - numPackedDims, rank));
  packingMetadata = computePackingMetadata(rank, innerDimsPos);
  SmallVector<int64_t> innerPositionsPerm =
      computePermutationVector(rank, lastDims, packingMetadata.insertPositions);

  SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
  if (!outerPerm.empty())
    applyPermutationToVector(outerPos, outerPerm);
  SmallVector<int64_t> outerPositionPerm =
      computePermutationVector(rank, packingMetadata.outerPositions, outerPos);

  SmallVector<int64_t> packInverseDestPermutation = innerPositionsPerm;
  applyPermutationToVector(packInverseDestPermutation, outerPositionPerm);
  return packInverseDestPermutation;
}

namespace mlir {
namespace linalg {

SmallVector<int64_t> getPackInverseDestPerm(PackOp packOp) {

  PackingMetadata pMetadata;
  int64_t packedRank = packOp.getDestType().getRank();
  ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
  ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
  SmallVector<int64_t> packInvDestPerm =
      computePackUnPackPerm(packedRank, innerDimPos, outerPerm, pMetadata);
  return packInvDestPerm;
}

SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp) {
  PackingMetadata metadata;
  return getUnPackInverseSrcPerm(unpackOp, metadata);
}

SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp,
                                             PackingMetadata &metadata) {
  int64_t unpackRank = unpackOp.getSourceType().getRank();
  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
  ArrayRef<int64_t> outerPerm = unpackOp.getOuterDimsPerm();
  SmallVector<int64_t> unpackInvSrcPerm =
      computePackUnPackPerm(unpackRank, innerDimPos, outerPerm, metadata);
  return unpackInvSrcPerm;
}

bool allIndexingsAreProjectedPermutation(LinalgOp op) {
  return llvm::all_of(op.getIndexingMapsArray(), [](AffineMap m) {
    return m.isProjectedPermutation(/*allowZeroInResults=*/true);
  });
}

bool hasOnlyScalarElementwiseOp(Region &r) {
  if (!llvm::hasSingleElement(r))
    return false;
  for (Operation &op : r.front()) {
    if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
              linalg::YieldOp, linalg::IndexOp, AffineApplyOp>(op) ||
          OpTrait::hasElementwiseMappableTraits(&op)) ||
        llvm::any_of(op.getResultTypes(),
                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
      return false;
  }
  return true;
}

bool isElementwise(LinalgOp op) {
  if (op.getNumLoops() != op.getNumParallelLoops())
    return false;

  if (!allIndexingsAreProjectedPermutation(op))
    return false;

  // TODO: relax the restrictions on indexing map.
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    if (!op.getMatchingIndexingMap(&opOperand).isPermutation())
      return false;
  }
  return hasOnlyScalarElementwiseOp(op->getRegion(0));
}

bool isParallelIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::parallel;
}

bool isReductionIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::reduction;
}

Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                            Value source, Value pad, bool nofold) {
  // Exit if `source` is not defined by an ExtractSliceOp.
  auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Search the `source` use-def chain for padded LinalgOps.
  Value current = sliceOp.getSource();
  while (current) {
    auto linalgOp = current.getDefiningOp<LinalgOp>();
    if (!linalgOp)
      break;
    OpResult opResult = cast<OpResult>(current);
    current = linalgOp.getDpsInitOperand(opResult.getResultNumber())->get();
  }
  auto padOp = current ? current.getDefiningOp<tensor::PadOp>() : nullptr;

  // Exit if the search fails to match a tensor::PadOp at the end of the
  // matched LinalgOp sequence.
  if (!padOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padded result type does not match.
  if (sliceOp.getSource().getType() != type)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the LinalgOps are not high padded.
  if (llvm::any_of(padOp.getMixedLowPad(), [](OpFoldResult ofr) {
        return getConstantIntValue(ofr) != static_cast<int64_t>(0);
      }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if `padOpSliceOp`, which defines the slice used by
  // `padOp`, is rank-reducing.
  auto padOpSliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (!padOpSliceOp ||
      sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the sizes of the dynamic sizes of `sliceOp` do not match the size
  // of the slice padded by `padOp`.
  if (llvm::any_of(
          llvm::zip(sliceOp.getMixedSizes(), padOpSliceOp.getMixedSizes()),
          [](std::tuple<OpFoldResult, OpFoldResult> it) {
            return !isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it));
          }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padding values do not match.
  Attribute padOpPadAttr, padAttr;
  Value padOpPad = padOp.getConstantPaddingValue();
  if (!padOpPad || !matchPattern(padOpPad, m_Constant(&padOpPadAttr)) ||
      !matchPattern(pad, m_Constant(&padAttr)) || padOpPadAttr != padAttr)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Return the padded result if the padding values and sizes match.
  return sliceOp.getSource();
}

GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
  auto memrefTypeTo = cast<MemRefType>(to.getType());
#ifndef NDEBUG
  auto memrefTypeFrom = cast<MemRefType>(from.getType());
  assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
         "`from` and `to` memref must have the same rank");
#endif // NDEBUG

  AffineMap id =
      AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
  SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(),
                                                 utils::IteratorType::parallel);
  return b.create<linalg::GenericOp>(
      loc,
      /*inputs=*/from,
      /*outputs=*/to,
      /*indexingMaps=*/llvm::ArrayRef({id, id}),
      /*iteratorTypes=*/iteratorTypes,
      [](OpBuilder &b, Location loc, ValueRange args) {
        b.create<linalg::YieldOp>(loc, args.front());
      });
}
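// For reference, `makeMemRefCopyOp` above builds IR roughly of the following
// shape (illustrative sketch, not part of the upstream source; #id is the
// multi-dimensional identity map and f32 an arbitrary element type):
//
//   linalg.generic {indexing_maps = [#id, #id],
//                   iterator_types = ["parallel", ...]}
//       ins(%from : memref<...>) outs(%to : memref<...>) {
//     ^bb0(%in: f32, %out: f32):
//       linalg.yield %in : f32
//   }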

/// Specialization to build an scf "for" nest.
template <>
void GenerateLoopNest<scf::ForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected as many entries for proc info as number of loops, even if "
         "they are null entries");
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
  LoopNest loopNest = mlir::scf::buildLoopNest(
      b, loc, lbs, ubs, steps, iterArgInitValues,
      [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
        assert(iterArgs.size() == iterArgInitValues.size() &&
               "expect the number of output tensors and iter args to match");
        SmallVector<Value> operandValuesToUse = linalgOp->getOperands();
        if (!iterArgs.empty()) {
          operandValuesToUse = linalgOp.getDpsInputs();
          operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
        }
        return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
      });

  if (loopNest.loops.empty() || procInfo.empty())
    return;

  // Filter out scf.for loops that were created out of parallel dimensions.
  for (const auto &loop : llvm::enumerate(loopNest.loops)) {
    if (procInfo[loop.index()].distributionMethod ==
        DistributionMethod::Cyclic) {
      mapLoopToProcessorIds(loop.value(), procInfo[loop.index()].procId,
                            procInfo[loop.index()].nprocs);
    }
  }
}

/// Specialization to build affine "for" nest.
template <>
void GenerateLoopNest<AffineForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> /*procInfo*/) {
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
  assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);

  // Affine loops require constant steps.
  SmallVector<int64_t, 4> constantSteps;
  constantSteps.reserve(steps.size());
  for (Value v : steps) {
    auto constVal = getConstantIntValue(v);
    assert(constVal.has_value() && "Affine loops require constant steps");
    constantSteps.push_back(constVal.value());
  }

  affine::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
                              [&](OpBuilder &b, Location loc, ValueRange ivs) {
                                bodyBuilderFn(b, loc, ivs,
                                              linalgOp->getOperands());
                              });
}

/// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
                                       Value nprocs, Value &lb, Value &ub,
                                       Value &step) {
  AffineExpr d0, d1;
  bindDims(b.getContext(), d0, d1);
  AffineExpr s0 = getAffineSymbolExpr(0, b.getContext());
  lb =
      affine::makeComposedAffineApply(b, loc, d0 + d1 * s0, {lb, procId, step});
  step = affine::makeComposedAffineApply(b, loc, d0 * s0, {nprocs, step});
}
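// In other words, for cyclic distribution the function above rewrites the
// bounds as (illustrative summary, not part of the upstream source):
//
//   lb'   = lb + procId * step
//   step' = nprocs * step
//
// so that each processor visits every `nprocs`-th iteration starting at its
// own offset; `ub` is left unchanged.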

/// Generates a loop nest consisting of scf.parallel and scf.for, depending
/// on the `iteratorTypes`. Consecutive parallel loops create a single
/// scf.parallel operation; each sequential loop creates a new scf.for
/// operation. The body of the innermost loop is populated by
/// `bodyBuilderFn` that accepts a range of induction variables for all
/// loops. `ivStorage` is used to store the partial list of induction
/// variables.
// TODO: this function can be made iterative instead. However, it
// will have at most as many recursive calls as nested loops, which rarely
// exceeds 10.
static void generateParallelLoopNest(
    OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
    ValueRange steps, ArrayRef<utils::IteratorType> iteratorTypes,
    ArrayRef<linalg::ProcInfo> procInfo,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
    SmallVectorImpl<Value> &ivStorage) {
  assert(lbs.size() == ubs.size());
  assert(lbs.size() == steps.size());
  assert(lbs.size() == iteratorTypes.size());
  assert(procInfo.empty() || (lbs.size() == procInfo.size()));

  // If there are no (more) loops to be generated, generate the body and be
  // done with it.
  if (iteratorTypes.empty()) {
    bodyBuilderFn(b, loc, ivStorage);
    return;
  }

  // If there are no outer parallel loops, generate one sequential loop and
  // recurse.
  if (!isParallelIterator(iteratorTypes.front())) {
    LoopNest singleLoop = buildLoopNest(
        b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(),
        [&](OpBuilder &b, Location loc, ValueRange ivs) {
          ivStorage.append(ivs.begin(), ivs.end());
          generateParallelLoopNest(
              b, loc, lbs.drop_front(), ubs.drop_front(), steps.drop_front(),
              iteratorTypes.drop_front(),
              procInfo.empty() ? procInfo : procInfo.drop_front(),
              bodyBuilderFn, ivStorage);
        });
    return;
  }

  unsigned nLoops = iteratorTypes.size();
  unsigned numProcessed = 0;
  DistributionMethod distributionMethod = DistributionMethod::None;
  if (procInfo.empty()) {
    numProcessed = nLoops - iteratorTypes.drop_while(isParallelIterator).size();
  } else {
    distributionMethod = procInfo.front().distributionMethod;
    numProcessed =
        nLoops - procInfo
                     .drop_while([&](linalg::ProcInfo p) {
                       return p.distributionMethod == distributionMethod;
                     })
                     .size();
  }

  auto remainderProcInfo =
      procInfo.empty() ? procInfo : procInfo.drop_front(numProcessed);
  switch (distributionMethod) {
  case DistributionMethod::None: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::Cyclic: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::CyclicNumProcsGeNumIters: {
    // Check (for the processed loops) that the iteration is in-bounds.
    ArithBuilder ab(b, loc);
    Value cond = ab.slt(lbs[0], ubs[0]);
    for (unsigned i = 1; i < numProcessed; ++i)
      cond = ab._and(cond, ab.slt(lbs[i], ubs[i]));
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    b.create<scf::IfOp>(loc, cond, [&](OpBuilder &b, Location loc) {
      generateParallelLoopNest(b, loc, lbs.drop_front(numProcessed),
                               ubs.drop_front(numProcessed),
                               steps.drop_front(numProcessed),
                               iteratorTypes.drop_front(numProcessed),
                               remainderProcInfo, bodyBuilderFn, ivStorage);
      b.create<scf::YieldOp>(loc, ValueRange{});
    });
    return;
  }
  case DistributionMethod::CyclicNumProcsEqNumIters:
    // No check/loops needed here. Set the `%iv` to be the `%lb` and proceed
    // with inner loop generation.
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    generateParallelLoopNest(
        b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
        steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
        remainderProcInfo, bodyBuilderFn, ivStorage);
    return;
  }
}

/// Specialization for generating a mix of parallel and sequential scf loops.
template <>
void GenerateLoopNest<scf::ParallelOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
  assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
  // This function may be passed more iterator types than ranges.
  assert(iteratorTypes.size() >= loopRanges.size() &&
         "expected iterator type for all ranges");
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected proc information for all loops when present");
  iteratorTypes = iteratorTypes.take_front(loopRanges.size());
  SmallVector<Value, 8> lbsStorage, ubsStorage, stepsStorage, ivs;
  unsigned numLoops = iteratorTypes.size();
  ivs.reserve(numLoops);
  lbsStorage.reserve(numLoops);
  ubsStorage.reserve(numLoops);
  stepsStorage.reserve(numLoops);

  // Get the loop lb, ub, and step.
  unpackRanges(b, loc, loopRanges, lbsStorage, ubsStorage, stepsStorage);

  // Modify the lb, ub, and step based on the distribution options.
  for (const auto &it : llvm::enumerate(procInfo)) {
    if (it.value().distributionMethod != linalg::DistributionMethod::None) {
      updateBoundsForCyclicDistribution(
          b, loc, it.value().procId, it.value().nprocs, lbsStorage[it.index()],
          ubsStorage[it.index()], stepsStorage[it.index()]);
    }
  }
  ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
  generateParallelLoopNest(
      b, loc, lbs, ubs, steps, iteratorTypes, procInfo,
      [&](OpBuilder &b, Location loc, ValueRange ivs) {
        bodyBuilderFn(b, loc, ivs, linalgOp->getOperands());
      },
      ivs);

  assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
}

static Operation *materializeTiledShape(OpBuilder &builder, Location loc,
                                        Value valueToTile,
                                        const SliceParameters &sliceParams) {
  auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
  auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
                      .Case([&](MemRefType) {
                        return builder.create<memref::SubViewOp>(
                            loc, valueToTile, sliceParams.offsets,
                            sliceParams.sizes, sliceParams.strides);
                      })
                      .Case([&](RankedTensorType) {
                        return builder.create<tensor::ExtractSliceOp>(
                            loc, valueToTile, sliceParams.offsets,
                            sliceParams.sizes, sliceParams.strides);
                      })
                      .Default([](ShapedType) -> Operation * {
                        llvm_unreachable("Unexpected shaped type");
                      });
  return sliceOp;
}

Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
                          ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                          ArrayRef<OpFoldResult> lbs,
                          ArrayRef<OpFoldResult> ubs,
                          ArrayRef<OpFoldResult> subShapeSizes,
                          bool omitPartialTileCheck) {
  SliceParameters sliceParams =
      computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
                             ubs, subShapeSizes, omitPartialTileCheck);
  return materializeTiledShape(builder, loc, valueToTile, sliceParams);
}

SliceParameters
computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
                       ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                       ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
                       ArrayRef<OpFoldResult> subShapeSizes,
                       bool omitPartialTileCheck) {
  auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
  assert(shapedType && "only shaped types can be tiled");
  ArrayRef<int64_t> shape = shapedType.getShape();
  int64_t rank = shapedType.getRank();

  // Compute offsets/sizes/strides for the tile.
  SliceParameters sliceParams;
  sliceParams.offsets.reserve(rank);
  sliceParams.sizes.reserve(rank);
  sliceParams.strides.reserve(rank);
  for (unsigned r = 0; r < rank; ++r) {
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r);
    if (!isTiled(map.getSubMap({r}), tileSizes)) {
      sliceParams.offsets.push_back(builder.getIndexAttr(0));
      OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
      sliceParams.sizes.push_back(dim);
      sliceParams.strides.push_back(builder.getIndexAttr(1));
      LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");

    // Tiling creates a new slice at the proper index, the slice step is 1
    // (i.e. the op does not subsample, stepping occurs in the loop).
    auto m = map.getSubMap({r});
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: submap: " << m << "\n");
    IRRewriter rewriter(builder);
    // The offset of the slice is m(lbs) - m(0).
    SmallVector<Attribute> zeros(lbs.size(), rewriter.getIndexAttr(0));
    SmallVector<Attribute> mAtZero;
    [[maybe_unused]] auto res = m.constantFold(zeros, mAtZero);
    assert(succeeded(res) && "affine_map must be evaluatable (not symbols)");
    int64_t mAtZeroInt =
        cast<IntegerAttr>(mAtZero[0]).getValue().getSExtValue();
    OpFoldResult offset = makeComposedFoldedAffineApply(
        rewriter, loc, m.getResult(0) - mAtZeroInt, lbs);
    sliceParams.offsets.push_back(offset);

    OpFoldResult closedIntSize =
        makeComposedFoldedAffineApply(rewriter, loc, m, subShapeSizes);
    // Resulting size needs to be made half open interval again.
    AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
    OpFoldResult size =
        makeComposedFoldedAffineApply(rewriter, loc, s0 + 1, closedIntSize);
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: raw size: " << size << "\n");
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: new offset: " << offset << "\n");
    sliceParams.strides.push_back(builder.getIndexAttr(1));

    if (omitPartialTileCheck) {
      // We statically know that the partial/boundary tile condition is
      // unnecessary.
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
      sliceParams.sizes.push_back(size);
      continue;
    }

    // The size of the subview / extract_slice should be trimmed to avoid
    // out-of-bounds accesses, unless:
    // a. We statically know the subshape size divides the shape size evenly.
    // b. The subshape size is 1. According to the way the loops are set up,
    //    tensors with "0" dimensions would never be constructed.
    int64_t shapeSize = shape[r];
    std::optional<int64_t> sizeCst = getConstantIntValue(size);
    auto hasTileSizeOne = sizeCst && *sizeCst == 1;
    auto dividesEvenly = sizeCst && !ShapedType::isDynamic(shapeSize) &&
                         ((shapeSize % *sizeCst) == 0);
    if (!hasTileSizeOne && !dividesEvenly) {
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
                              << ", size: " << size
                              << ": make sure in bound with affine.min\n");

      AffineExpr dim0, dim1, dim2;
      MLIRContext *context = builder.getContext();
      bindDims(context, dim0, dim1, dim2);

      // Get the dimension size for this dimension. We need to first calculate
      // the max index and then add one. This is important because for
      // convolution ops, the input window dimension's affine map has the form
      // `(d0 * s0 + d1)`, where `d0`/`d1` is an output/filter window dimension
      // and `s0` is the stride. Directly using the size of the output/filter
      // window dimensions would yield an incorrect result.
      AffineMap minusOneMap = AffineMap::inferFromExprList(
                                  {ArrayRef<AffineExpr>{dim0 - 1}}, context)
                                  .front();
      AffineMap plusOneMap = AffineMap::inferFromExprList(
                                 {ArrayRef<AffineExpr>{dim0 + 1}}, context)
                                 .front();
      SmallVector<OpFoldResult> maxIndices =
          llvm::to_vector(llvm::map_range(ubs, [&](OpFoldResult ub) {
            return makeComposedFoldedAffineApply(rewriter, loc, minusOneMap,
                                                 {ub});
          }));
      OpFoldResult maxIndex =
          makeComposedFoldedAffineApply(rewriter, loc, m, maxIndices);
      OpFoldResult d =
          makeComposedFoldedAffineApply(rewriter, loc, plusOneMap, {maxIndex});

      // Compute min(dim - offset, size) to avoid out-of-bounds accesses.
      AffineMap minMap = AffineMap::inferFromExprList(
                             {ArrayRef<AffineExpr>{dim1 - dim2, dim0}}, context)
                             .front();
      size =
          makeComposedFoldedAffineMin(rewriter, loc, minMap, {size, d, offset});
    }
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
    sliceParams.sizes.push_back(size);
  }
  return sliceParams;
}
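// Example for `computeSliceParameters` above (illustrative, not part of the
// upstream source): for an identity indexing map on dimension r, loop lower
// bound %iv, tile size 8 and loop upper bound equal to the dimension size %d,
// the computed slice parameters are roughly
//   offset = %iv, size = min(8, %d - %iv), stride = 1,
// i.e. the last (partial) tile is clamped to stay in bounds unless
// `omitPartialTileCheck` is set.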

SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
                                             ArrayRef<OpFoldResult> ivs,
                                             ArrayRef<OpFoldResult> tileSizes) {
  SmallVector<OpFoldResult> offsets;
  for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
    bool isTiled = !isZeroInteger(tileSizes[idx]);
    offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0));
    LLVM_DEBUG(llvm::dbgs()
               << "computeTileOffsets: " << offsets.back() << "\n");
  }
  return offsets;
}
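// Example for `computeTileOffsets` above (illustrative, not part of the
// upstream source): with `tileSizes = [0, 8, 0, 16]` and `ivs = [%i, %j]`,
// the result is `[0, %i, 0, %j]`; induction variables are only consumed for
// the tiled (non-zero) dimensions.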

SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
                                           ArrayRef<OpFoldResult> tileSizes,
                                           ArrayRef<OpFoldResult> sizeBounds) {
  SmallVector<OpFoldResult> sizes;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
    bool isTiled = !isZeroInteger(tileSizes[idx]);
    // Before composing, we need to make range a closed interval.
    OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx];
    AffineExpr d0 = getAffineDimExpr(0, b.getContext());
    IRRewriter rewriter(b);
    sizes.push_back(makeComposedFoldedAffineApply(rewriter, loc, d0 - 1, size));
    LLVM_DEBUG(llvm::dbgs() << "computeTileSizes: " << sizes.back() << "\n");
  }
  return sizes;
}
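// Example for `computeTileSizes` above (illustrative, not part of the upstream
// source): with `tileSizes = [8, 0]` and `sizeBounds = [%m, %n]`, the result
// is `[7, %n - 1]`; each entry is the closed-interval upper bound (size - 1)
// of either the tile or, for untiled dimensions, the full size bound.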

SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
  if (op.hasPureBufferSemantics())
    return {};
  return llvm::to_vector(
      llvm::map_range(op.getDpsInitsMutable(), [&](OpOperand &opOperand) {
        return operands[opOperand.getOperandNumber()].getType();
      }));
}

SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
                                    LinalgOp op, ValueRange operands,
                                    ValueRange results) {
  if (op.hasPureBufferSemantics())
    return {};
  SmallVector<Value> tensorResults;
  tensorResults.reserve(results.size());
  // Insert an insert_slice for each output tensor.
  unsigned resultIdx = 0;
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    // TODO: use an interface/adaptor to avoid leaking position in
    // `tiledOperands`.
    Value outputTensor = operands[opOperand.getOperandNumber()];
    if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
      Value inserted = builder.create<tensor::InsertSliceOp>(
          loc, sliceOp.getSource().getType(), results[resultIdx],
          sliceOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(),
          sliceOp.getStrides(), sliceOp.getStaticOffsets(),
          sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
      tensorResults.push_back(inserted);
    } else {
      tensorResults.push_back(results[resultIdx]);
    }
    ++resultIdx;
  }
  return tensorResults;
}

SmallVector<std::optional<SliceParameters>>
computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
                          ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
                          ArrayRef<OpFoldResult> tileSizes,
                          ArrayRef<OpFoldResult> sizeBounds,
                          bool omitPartialTileCheck) {
  assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                           llvm::make_range(tileSizes.begin(), tileSizes.end()),
                           [](OpFoldResult v) { return !isZeroInteger(v); })) &&
         "expected as many ivs as non-zero sizes");

  // Construct (potentially temporary) mins and maxes on which to apply maps
  // that define tile subshapes.
  SmallVector<OpFoldResult> lbs =
      computeTileOffsets(builder, loc, ivs, tileSizes);
  SmallVector<OpFoldResult> subShapeSizes =
      computeTileSizes(builder, loc, tileSizes, sizeBounds);

  assert(static_cast<int64_t>(valuesToTile.size()) <=
             linalgOp->getNumOperands() &&
         "more values to tile than operands");
  SmallVector<std::optional<SliceParameters>> allSliceParams;
  allSliceParams.reserve(valuesToTile.size());
  for (auto [opOperand, val] :
       llvm::zip(linalgOp->getOpOperands(), valuesToTile)) {
    Value shapedOp = val;
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
    AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
    // Use `opOperand` as is if it is not tiled and not an output tensor. Having
    // an extract/insert slice pair for all output tensors simplifies follow up
    // transformations such as padding and bufferization since the
    // extract/insert slice pairs make the accessed iteration argument
    // subdomains explicit.

    Type operandType = opOperand.get().getType();
    if (!isTiled(map, tileSizes) && !(isa<RankedTensorType>(operandType) &&
                                      linalgOp.isDpsInit(&opOperand))) {
      allSliceParams.push_back(std::nullopt);
      LLVM_DEBUG(llvm::dbgs()
                 << ": not tiled: use shape: " << operandType << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");

    allSliceParams.push_back(computeSliceParameters(
        builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
        omitPartialTileCheck));
  }

  return allSliceParams;
}

SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
                                   LinalgOp linalgOp, ValueRange valuesToTile,
                                   ArrayRef<OpFoldResult> ivs,
                                   ArrayRef<OpFoldResult> tileSizes,
                                   ArrayRef<OpFoldResult> sizeBounds,
                                   bool omitPartialTileCheck) {
  SmallVector<std::optional<SliceParameters>> allSliceParameter =
      computeAllSliceParameters(builder, loc, linalgOp, valuesToTile, ivs,
                                tileSizes, sizeBounds, omitPartialTileCheck);
  SmallVector<Value> tiledShapes;
  for (auto item : llvm::zip(valuesToTile, allSliceParameter)) {
    Value valueToTile = std::get<0>(item);
    std::optional<SliceParameters> sliceParams = std::get<1>(item);
    tiledShapes.push_back(
        sliceParams.has_value()
            ? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
                  ->getResult(0)
            : valueToTile);
  }
  return tiledShapes;
}

void offsetIndices(OpBuilder &b, LinalgOp linalgOp,
                   ArrayRef<OpFoldResult> offsets) {
  IRRewriter rewriter(b);
  offsetIndices(rewriter, linalgOp, offsets);
}

void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
                   ArrayRef<OpFoldResult> offsets) {
  if (!linalgOp.hasIndexSemantics())
    return;

  for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) {
    if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()])
      continue;
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointAfter(indexOp);
    AffineExpr index, offset;
    bindDims(b.getContext(), index, offset);
    OpFoldResult applied = makeComposedFoldedAffineApply(
        b, indexOp.getLoc(), index + offset,
        {getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
    Value materialized =
        getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
    b.replaceUsesWithIf(indexOp, materialized, [&](OpOperand &use) {
      return use.getOwner() != materialized.getDefiningOp();
    });
  }
}
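// Example for `offsetIndices` above (illustrative, not part of the upstream
// source): inside a tiled op with tile offset %off for dimension 0, a
// `%i = linalg.index 0` in the body is rewritten so that all of its other
// uses see the result of an `affine.apply` computing `%i + %off` instead,
// i.e. tile-local indices are shifted back into the original iteration space.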

/// Get the reassociation maps to fold the result of an extract_slice (or
/// source of an insert_slice) operation with given offsets, and sizes to its
/// rank-reduced version. This is only done for the cases where the size is 1
/// and offset is 0. Strictly speaking the offset 0 is not required in general,
/// but non-zero offsets are not handled by SPIR-V backend at this point (and
/// potentially cannot be handled).
std::optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<ReassociationIndices> reassociation;
  ReassociationIndices curr;
  for (const auto &it : llvm::enumerate(mixedSizes)) {
    auto dim = it.index();
    auto size = it.value();
    curr.push_back(dim);
    auto attr = llvm::dyn_cast_if_present<Attribute>(size);
    if (attr && cast<IntegerAttr>(attr).getInt() == 1)
      continue;
    reassociation.emplace_back(ReassociationIndices{});
    std::swap(reassociation.back(), curr);
  }
  // When the reassociations are not empty, then fold the remaining
  // unit-dimensions into the last dimension. If the reassociations so far are
  // empty, then leave it empty. This will fold everything to a rank-0 tensor.
  if (!curr.empty() && !reassociation.empty())
    reassociation.back().append(curr.begin(), curr.end());
  return reassociation;
}
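// Example for `getReassociationMapForFoldingUnitDims` above (illustrative, not
// part of the upstream source): for `mixedSizes = [1, 4, 1, 8]` the result is
// `[[0, 1], [2, 3]]`, i.e. each static unit dimension is folded into the next
// non-unit dimension; trailing unit dimensions, as in `[4, 1, 1]`, are folded
// into the last group, yielding `[[0, 1, 2]]`.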

} // namespace linalg
} // namespace mlir
