1//===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the linalg dialect Tiling pass.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Affine/IR/AffineOps.h"
14#include "mlir/Dialect/Affine/LoopUtils.h"
15#include "mlir/Dialect/Arith/Utils/Utils.h"
16#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
17#include "mlir/Dialect/Linalg/IR/Linalg.h"
18#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19#include "mlir/Dialect/MemRef/IR/MemRef.h"
20#include "mlir/Dialect/SCF/Transforms/Transforms.h"
21#include "mlir/Dialect/Tensor/IR/Tensor.h"
22#include "mlir/Dialect/Utils/IndexingUtils.h"
23#include "mlir/Dialect/Utils/StaticValueUtils.h"
24#include "mlir/IR/AffineExpr.h"
25#include "mlir/IR/AffineMap.h"
26#include "mlir/IR/ValueRange.h"
27#include "mlir/Transforms/FoldUtils.h"
28#include "llvm/ADT/STLExtras.h"
29#include <utility>
30
31namespace mlir {
32#define GEN_PASS_DEF_LINALGTILINGPASS
33#include "mlir/Dialect/Linalg/Passes.h.inc"
34} // namespace mlir
35
36using namespace mlir;
37using namespace mlir::affine;
38using namespace mlir::linalg;
39using namespace mlir::scf;
40
41#define DEBUG_TYPE "linalg-tiling"
42
43std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
44mlir::linalg::makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
45 ArrayRef<OpFoldResult> allShapeSizes,
46 ArrayRef<OpFoldResult> allTileSizes) {
47 assert(allTileSizes.size() == map.getNumResults());
48 // Apply `map` to get shape sizes in loop order.
49 SmallVector<OpFoldResult> shapeSizes =
50 makeComposedFoldedMultiResultAffineApply(b, loc, map, operands: allShapeSizes);
51 SmallVector<OpFoldResult> tileSizes(allTileSizes);
52
53 // Traverse the tile sizes, which are in loop order, erase zeros everywhere.
54 LoopIndexToRangeIndexMap loopIndexToRangeIndex;
55 for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
56 if (getConstantIntValue(ofr: tileSizes[idx - zerosCount]) ==
57 static_cast<int64_t>(0)) {
58 shapeSizes.erase(CI: shapeSizes.begin() + idx - zerosCount);
59 tileSizes.erase(CI: tileSizes.begin() + idx - zerosCount);
60 ++zerosCount;
61 continue;
62 }
63 loopIndexToRangeIndex[idx] = idx - zerosCount;
64 }
65
66 // Create a new range with the applied tile sizes.
67 SmallVector<Range, 4> res;
68 for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
69 res.push_back(Elt: Range{.offset: b.getIndexAttr(value: 0), .size: shapeSizes[idx], .stride: tileSizes[idx]});
70 return std::make_tuple(args&: res, args&: loopIndexToRangeIndex);
71}
72
73void mlir::linalg::transformIndexOps(
74 RewriterBase &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
75 const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
76 SmallVector<Value> allIvs(op.getNumLoops(), nullptr);
77 for (auto en : enumerate(First&: allIvs)) {
78 auto rangeIndex = loopIndexToRangeIndex.find(Val: en.index());
79 if (rangeIndex == loopIndexToRangeIndex.end())
80 continue;
81 en.value() = ivs[rangeIndex->second];
82 }
83 offsetIndices(b, linalgOp: op, offests: getAsOpFoldResult(values: allIvs));
84}
85
86/// Asserts that the given index-typed value is strictly positive. If the value
87/// is an attribute, asserts at compile time, otherwise emits an assertion
88/// checked at runtime.
89static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
90 OpFoldResult value) {
91 if (auto attr = llvm::dyn_cast_if_present<Attribute>(Val&: value)) {
92 assert(cast<IntegerAttr>(attr).getValue().isStrictlyPositive() &&
93 "expected strictly positive tile size and divisor");
94 return;
95 }
96
97 Value zero = b.create<arith::ConstantIndexOp>(args: 0);
98 Value condition = b.create<arith::CmpIOp>(args: arith::CmpIPredicate::sgt,
99 args: cast<Value>(Val&: value), args&: zero);
100 b.create<cf::AssertOp>(
101 args&: condition,
102 args: b.getStringAttr(bytes: "expected strictly positive tile size and divisor"));
103}
104
105FailureOr<StaticContinuousTileSizeSpecification>
106mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
107 unsigned targetSize) {
108
109 assert(!op.hasDynamicShape() &&
110 "cannot compute static multi-tile sizes for an op with dynamic shape");
111 assert(targetSize > 0 && "target size must be non-negative");
112 assert(dimension < op.getNumLoops() && "dimension overflow");
113
114 StaticContinuousTileSizeSpecification spec;
115 int64_t loopRange = op.getStaticLoopRanges()[dimension];
116 int64_t tripCount = loopRange / targetSize;
117
118 unsigned tileSize = targetSize;
119
120 spec.tileSizes.push_back(Elt: tileSize);
121 spec.tripCounts.push_back(Elt: tripCount);
122
123 int64_t remainderChunk = loopRange % targetSize;
124
125 while (tileSize > 1 && remainderChunk != 0) {
126
127 uint64_t maxPower = llvm::bit_floor(Value: tileSize);
128 tileSize = maxPower == tileSize ? maxPower >> 1 : maxPower;
129
130 tripCount = remainderChunk / tileSize;
131
132 if (tripCount > 0) {
133 spec.tileSizes.push_back(Elt: tileSize);
134 spec.tripCounts.push_back(Elt: tripCount);
135 }
136
137 remainderChunk = remainderChunk % tileSize;
138 }
139
140 auto tripCountCheck = [&](SmallVector<int64_t> tileSizes,
141 SmallVector<int64_t> tripCounts,
142 int64_t range) -> bool {
143 int64_t computedRange = 0;
144 for (auto [tileSize, tripCount] : llvm::zip(t&: tileSizes, u&: tripCounts))
145 computedRange += tileSize * tripCount;
146 return range == computedRange;
147 };
148
149 if (!tripCountCheck(spec.tileSizes, spec.tripCounts, loopRange))
150 return failure();
151
152 return spec;
153}
154
/// Dynamic counterpart of `computeStaticContinuousTileSizes`: emits IR that
/// computes a first tile of `targetSize` (which must fold to a constant) and
/// then progressively smaller power-of-two tiles covering the remainder of
/// loop `dimension` of `op`'s iteration domain. When `emitAssertions` is set,
/// a runtime check that `targetSize > 0` is emitted.
FailureOr<ContinuousTileSizeSpecification>
mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
                                         unsigned dimension,
                                         OpFoldResult targetSize,
                                         bool emitAssertions) {

  SmallVector<Range> loopRanges = op.getIterationDomain(builder);
  unsigned numLoops = loopRanges.size();

  // Bail out on dimension overflow.
  if (dimension >= numLoops)
    return failure();

  // The code below works only on values.
  Location loc = op->getLoc();
  ImplicitLocOpBuilder b(loc, builder);
  if (emitAssertions) {
    emitIsPositiveIndexAssertion(b, targetSize);
  }
  Value targetSizeValue =
      getValueOrCreateConstantIndexOp(builder, loc, targetSize);

  // Find the trip count of the iteration space dimension for which the tile
  // sizes are computed.
  Value loopRange =
      getValueOrCreateConstantIndexOp(b, loc, loopRanges[dimension].size);
  ContinuousTileSizeSpecification spec;

  // Compute the tile sizes and the respective numbers of tiles.
  AffineExpr s0 = b.getAffineSymbolExpr(0);
  AffineExpr s1 = b.getAffineSymbolExpr(1);
  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
    return affine::makeComposedAffineApply(b, b.getLoc(), expr, ofrs);
  };

  // tripCount = range floordiv targetSize; remainder = range mod targetSize.
  Value tripCountValue = apply(s0.floorDiv(s1), {loopRange, targetSizeValue});
  Value remainderChunkValue = apply(s0 % s1, {loopRange, targetSizeValue});

  // Folded form of the trip count, used below to detect zero trip counts.
  OpFoldResult tripCountSize = affine::makeComposedFoldedAffineApply(
      b, b.getLoc(), s0.floorDiv(s1), {loopRange, targetSizeValue});

  // emitAssertions above already asserts that targetSize is
  // a positive integer.
  // NOTE(review): this dereference requires `targetSize` to fold to a
  // constant; a fully dynamic target size would trip the assert below.
  uint64_t tileSizeInt = *getConstantIntValue(targetSizeValue);

  assert(tileSizeInt > 0 && "target size must be non-negative");

  spec.tileSizes.push_back(targetSizeValue);
  spec.tripCounts.push_back(tripCountValue);

  // Cover the remainder with progressively smaller power-of-two tile sizes,
  // mirroring the static algorithm but emitting IR for the trip counts.
  while (tileSizeInt > 1) {
    // Halve when already a power of two, otherwise round down to one.
    uint64_t maxPower = llvm::bit_floor(tileSizeInt);
    tileSizeInt = maxPower == tileSizeInt ? maxPower >> 1 : maxPower;
    auto constStepOp =
        builder.createOrFold<arith::ConstantIndexOp>(b.getLoc(), tileSizeInt);
    tripCountValue = apply(s0.floorDiv(s1), {remainderChunkValue, constStepOp});

    tripCountSize = affine::makeComposedFoldedAffineApply(
        b, b.getLoc(), s0.floorDiv(s1), {remainderChunkValue, constStepOp});

    // Optimization if tripCount can be determined to be zero.
    if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tripCountSize)) {
      auto intAttr = cast<IntegerAttr>(attr);
      bool isTripCountZero = intAttr.getValue().isZero();

      if (!isTripCountZero) {
        spec.tileSizes.push_back(constStepOp);
        spec.tripCounts.push_back(tripCountValue);
      }
    } else {
      // Trip count is dynamic: keep the tile conservatively.
      spec.tileSizes.push_back(constStepOp);
      spec.tripCounts.push_back(tripCountValue);
    }

    remainderChunkValue = apply(s0 % s1, {remainderChunkValue, constStepOp});
  }

  return spec;
}
234
235FailureOr<StaticMultiSizeSpecification>
236mlir::linalg::computeStaticMultiTileSizes(LinalgOp op, unsigned dimension,
237 int64_t targetSize, int64_t divisor) {
238 assert(!op.hasDynamicShape() &&
239 "cannot compute static multi-tile sizes for an op with dynamic shape");
240 assert(targetSize > 0 && "target size must be non-negative");
241 assert(divisor > 0 && "divisor must be non-negative");
242 assert(dimension < op.getNumLoops() && "dimension overflow");
243
244 StaticMultiSizeSpecification spec;
245 int64_t tripCount = op.getStaticLoopRanges()[dimension];
246 int64_t a = tripCount / divisor;
247 int64_t t = (targetSize + divisor - 1) / divisor;
248 int64_t totalTripCount = (a + t - 1) / t;
249 spec.lowTileSize = (a / totalTripCount) * divisor;
250 spec.highTileSize = spec.lowTileSize + divisor;
251 spec.highTripCount = a % totalTripCount;
252 spec.lowTripCount = totalTripCount - spec.highTripCount;
253 if (spec.lowTileSize * spec.lowTripCount +
254 spec.highTileSize * spec.highTripCount !=
255 tripCount) {
256 return failure();
257 }
258 return spec;
259}
260
/// Dynamic counterpart of `computeStaticMultiTileSizes`: emits affine.apply
/// ops computing two tile sizes (both multiples of `divisor`, close to
/// `targetSize`) and their trip counts such that they cover loop `dimension`
/// of `op`. When `emitAssertions` is set, also emits runtime checks that the
/// inputs are positive and that the computed tiles cover the range exactly.
FailureOr<MultiSizeSpecification>
mlir::linalg::computeMultiTileSizes(OpBuilder &builder, LinalgOp op,
                                    unsigned dimension, OpFoldResult targetSize,
                                    OpFoldResult divisor, bool emitAssertions) {
  // Bail out on dimension overflow.
  if (dimension >= op.getNumLoops())
    return failure();

  // The code below works only on values.
  Location loc = op.getLoc();
  ImplicitLocOpBuilder b(loc, builder);
  if (emitAssertions) {
    emitIsPositiveIndexAssertion(b, targetSize);
    emitIsPositiveIndexAssertion(b, divisor);
  }
  Value targetSizeValue =
      getValueOrCreateConstantIndexOp(builder, loc, targetSize);
  Value divisorValue = getValueOrCreateConstantIndexOp(builder, loc, divisor);

  // Find the trip count of the iteration space dimension for which the tile
  // sizes are computed.
  SmallVector<OpFoldResult> allShapes =
      op.createFlatListOfOperandDims(b, b.getLoc());
  AffineMap shapesToLoops = op.getShapesToLoopsMap();
  SmallVector<OpFoldResult> loopRanges =
      makeComposedFoldedMultiResultAffineApply(b, op.getLoc(), shapesToLoops,
                                               allShapes);
  Value tripCount =
      getValueOrCreateConstantIndexOp(b, op.getLoc(), loopRanges[dimension]);

  // Compute the tile sizes and the respective numbers of tiles.
  // Same algebra as computeStaticMultiTileSizes, but in IR:
  //   a = tripCount floordiv divisor        (range in divisor units)
  //   t = ceildiv(targetSize, divisor)      (target in divisor units)
  //   d = ceildiv(a, t)                     (total number of tiles)
  //   s = (a floordiv d) * divisor          (low tile size)
  //   v = a mod d                           (high-tile trip count)
  //   u = d - v                             (low-tile trip count)
  AffineExpr s0 = b.getAffineSymbolExpr(0);
  AffineExpr s1 = b.getAffineSymbolExpr(1);
  AffineExpr s2 = b.getAffineSymbolExpr(2);
  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
    return affine::makeComposedAffineApply(b, b.getLoc(), expr, ofrs);
  };
  Value a = apply(s0.floorDiv(s1), {tripCount, divisorValue});
  Value t = apply((s0 + s1 - 1).floorDiv(s1), {targetSizeValue, divisorValue});
  Value d = apply((s0 + s1 - 1).floorDiv(s1), {a, t});
  Value s = apply(s0.floorDiv(s1) * s2, {a, d, divisorValue});
  Value v = apply(s0 % s1, {a, d});
  Value u = apply(s0 - s1, {d, v});

  MultiSizeSpecification spec;
  spec.lowTileSize = s;
  // High tile is exactly one divisor larger than the low tile.
  spec.highTileSize = apply(s0 + s1, {s, divisorValue});
  spec.lowTripCount = u;
  spec.highTripCount = v;

  // If requested, emit the check that the tile sizes are computed correctly.
  // For example, for iteration dimension size of 15 and the target size 8 it is
  // impossible to find two tile sizes both divisible by 8 that fully cover the
  // original space dimension.
  if (emitAssertions) {
    AffineExpr s3 = builder.getAffineSymbolExpr(3);
    Value coveredSize =
        apply(s0 * s1 + s2 * s3, {spec.lowTileSize, spec.lowTripCount,
                                  spec.highTileSize, spec.highTripCount});
    Value equals = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
                                           coveredSize, tripCount);
    b.create<cf::AssertOp>(
        equals, builder.getStringAttr(
                    "could not compute dynamic multi-size tile shapes"));
  }

  return spec;
}
329
330/// Returns true if the maximum tile offset `tileSize * numThreads-1` is less
331/// than `iterationSize`.
332static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
333 OpFoldResult numThreads,
334 OpFoldResult iterationSize) {
335 std::optional<int64_t> tileSizeConst = getConstantIntValue(ofr: tileSize);
336 std::optional<int64_t> numThreadsConst = getConstantIntValue(ofr: numThreads);
337 std::optional<int64_t> iterSizeConst = getConstantIntValue(ofr: iterationSize);
338 if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
339 return false;
340 return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
341}
342
343/// Build an `affine_max` of all the `vals`.
344static OpFoldResult buildMax(OpBuilder &b, Location loc,
345 ArrayRef<OpFoldResult> vals) {
346 return affine::makeComposedFoldedAffineMax(
347 b, loc, map: AffineMap::getMultiDimIdentityMap(numDims: vals.size(), context: loc.getContext()),
348 operands: vals);
349}
350
351/// Build an `affine_min` of all the `vals`.
352static OpFoldResult buildMin(OpBuilder &b, Location loc,
353 ArrayRef<OpFoldResult> vals) {
354 return affine::makeComposedFoldedAffineMin(
355 b, loc, map: AffineMap::getMultiDimIdentityMap(numDims: vals.size(), context: loc.getContext()),
356 operands: vals);
357}
358
/// Fill out the `tiledOffsets` and `tiledSizes` to be used to tile to a given
/// number of threads.
///
/// One offset/size pair is produced per entry of `loopRanges`. Loops whose
/// `numThreads` entry is missing or constant-zero are left untiled (whole
/// domain). For tiled loops, the per-thread tile size is either the matching
/// `nominalTileSizes` entry or ceildiv(size, numThreads), clamped so that the
/// last threads do not read past the end of the domain, and (unless
/// `omitTileOffsetBoundsCheck` or statically provable in-bounds) clamped to a
/// minimum of 0. IR is inserted at the start of `forallOp`'s body.
static void calculateTileOffsetsAndSizes(
    RewriterBase &b, Location loc, scf::ForallOp forallOp,
    ArrayRef<OpFoldResult> numThreads, SmallVector<Range> loopRanges,
    bool omitTileOffsetBoundsCheck,
    std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
    SmallVector<OpFoldResult> &tiledOffsets,
    SmallVector<OpFoldResult> &tiledSizes) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPointToStart(forallOp.getBody(0));

  SmallVector<Value> threadIds = forallOp.getInductionVars();
  // Zero-thread entries denote untiled loops; only the non-zero entries have
  // a matching induction variable in `forallOp` (tracked by `threadIdIdx`).
  SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
      numThreads, [](OpFoldResult ofr) { return !isZeroInteger(ofr); });
  int64_t nLoops = loopRanges.size();
  tiledOffsets.reserve(nLoops);
  tiledSizes.reserve(nLoops);
  for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) {
    bool overflow = loopIdx >= numThreads.size();
    bool isZero = !overflow && isZeroInteger(numThreads[loopIdx]);
    // Degenerate case: take the whole domain.
    if (overflow || isZero) {
      tiledOffsets.push_back(loopRanges[loopIdx].offset);
      tiledSizes.push_back(loopRanges[loopIdx].size);
      continue;
    }

    // Tiled case: compute the offset and size.
    AffineExpr i, j, m, n, o;
    bindDims(b.getContext(), i, j);
    bindSymbols(b.getContext(), m, n, o);
    OpFoldResult size = loopRanges[loopIdx].size;
    OpFoldResult offset = loopRanges[loopIdx].offset;
    OpFoldResult threadId = threadIds[threadIdIdx];
    // Symbolic fixed max size per thread.
    // TODO: floor + 0/1 depending on case for better load-balancing.
    OpFoldResult tileSizePerThread =
        nominalTileSizes.has_value()
            ? (*nominalTileSizes)[loopIdx]
            : makeComposedFoldedAffineApply(
                  b, loc, m.ceilDiv(n),
                  ArrayRef<OpFoldResult>{size, nonZeroNumThreads[threadIdIdx]});

    // Dynamic offset shifted by threadId * maxSizePerThread.
    OpFoldResult offsetPerThread = makeComposedFoldedAffineApply(
        b, loc, i + j * m, {offset, threadId, tileSizePerThread});
    // Dynamic upper-bound depending on the threadId.
    // residual = offset + numThreads * tileSize - size; if it cannot be
    // proven zero, the last tiles may overhang and must be clamped.
    OpFoldResult residualTileSize = makeComposedFoldedAffineApply(
        b, loc, i + j * m - n,
        {offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size});
    if (!isZeroInteger(residualTileSize)) {
      OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply(
          b, loc, -i + m, {offsetPerThread, size});
      tileSizePerThread =
          buildMin(b, loc, {sizeMinusOffsetPerThread, tileSizePerThread});
    }

    tiledOffsets.push_back(Elt: offsetPerThread);
    // TODO: if tileSizePerThread <= 0 early exit.
    if (!omitTileOffsetBoundsCheck &&
        !canOmitTileOffsetInBoundsCheck(tileSizePerThread,
                                        nonZeroNumThreads[threadIdIdx], size))
      tileSizePerThread =
          buildMax(b, loc, {b.getIndexAttr(0), tileSizePerThread});

    tiledSizes.push_back(tileSizePerThread);
    ++threadIdIdx;
  }
}
429
/// Tiles `op` with the given (loop-ordered) `tileSizes`, generating a loop
/// nest of type `LoopTy`. A zero tile size leaves the corresponding loop
/// untiled; if all tile sizes are zero the op is simply cloned. Honors the
/// interchange and distribution settings in `options`. The signature taking
/// only `options` (below) resolves the tile sizes first and forwards here.
template <typename LoopTy>
static FailureOr<TiledLinalgOp>
tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
                 const LinalgTilingOptions &options) {
  OpBuilder::InsertionGuard g(b);

  auto nLoops = op.getNumLoops();
  // Initial tile sizes may be too big, only take the first nLoops.
  tileSizes = tileSizes.take_front(nLoops);

  // All-zero tile sizes: nothing to tile, return a clone of the op.
  if (llvm::all_of(tileSizes, [](OpFoldResult ofr) {
        return getConstantIntValue(ofr) == static_cast<int64_t>(0);
      })) {
    TiledLinalgOp tiledOp;
    tiledOp.op = cast<LinalgOp>(b.clone(*op.getOperation()));
    tiledOp.tensorResults.assign(tiledOp.op->result_begin(),
                                 tiledOp.op->result_end());
    return tiledOp;
  }

  // 1. Build the tiled loop ranges.
  SmallVector<OpFoldResult> allShapeSizes =
      op.createFlatListOfOperandDims(b, op.getLoc());
  AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
  if (!shapeSizesToLoopsMap)
    return failure();

  auto [loopRanges, loopIndexToRangeIndex] = makeTiledLoopRanges(
      b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);

  // Keep only the iterator types of loops that are actually tiled (zero tile
  // sizes were pruned by makeTiledLoopRanges).
  SmallVector<utils::IteratorType, 4> iteratorTypes;
  for (const auto &attr : enumerate(op.getIteratorTypesArray())) {
    if (loopIndexToRangeIndex.count(attr.index()))
      iteratorTypes.push_back(attr.value());
  }
  // If interchangeVector is empty, use the identity. Build the permutation map
  // otherwise.
  auto invPermutationMap =
      AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext());
  if (!options.interchangeVector.empty()) {
    // Based on the pruned iterations (due to zero tile size), recompute the
    // interchange vector.
    SmallVector<unsigned, 4> interchangeVector;
    interchangeVector.reserve(options.interchangeVector.size());
    for (auto pos : options.interchangeVector) {
      auto it = loopIndexToRangeIndex.find(pos);
      if (it == loopIndexToRangeIndex.end())
        continue;
      interchangeVector.push_back(it->second);
    }
    // Interchange vector is guaranteed to be a permutation,
    // `inversePermutation` must succeed.
    invPermutationMap = inversePermutation(
        AffineMap::getPermutationMap(interchangeVector, b.getContext()));
    assert(invPermutationMap);
    SmallVector<int64_t> permutation(interchangeVector.begin(),
                                     interchangeVector.end());
    applyPermutationToVector(loopRanges, permutation);
    applyPermutationToVector(iteratorTypes, permutation);
  }

  // Handle distribution. Create a vector of the same size of loops that are to
  // be tiled.
  SmallVector<linalg::ProcInfo> procInfo;
  if (options.distribution) {
    procInfo.resize(
        iteratorTypes.size(),
        linalg::ProcInfo{nullptr, nullptr, linalg::DistributionMethod::None});
    // Collect loop ranges of tiled loops, loops that are parallel.
    // Only the leading run of parallel loops is distributed.
    SmallVector<Range> parallelLoopRanges;
    for (const auto &iteratorType : llvm::enumerate(iteratorTypes)) {
      if (!isParallelIterator(iteratorType.value()))
        break;
      parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);
    }
    auto returnedProcInfo =
        options.distribution->procInfo(b, op.getLoc(), parallelLoopRanges);
    unsigned procIdIdx = 0;
    // Update the distribution information for the loops.
    for (const auto &iteratorType : llvm::enumerate(iteratorTypes)) {
      if (!isParallelIterator(iteratorType.value()))
        break;
      procInfo[iteratorType.index()] = returnedProcInfo[procIdIdx++];
    }
  }

  // 2. Create the tiled loops.
  LinalgOp res = op;
  SmallVector<Value, 4> ivs, tensorResults;
  auto tiledLoopBodyBuilder =
      [&](OpBuilder &builder, Location loc, ValueRange localIvs,
          ValueRange operandValuesToUse) -> scf::ValueVector {
    ivs.assign(localIvs.begin(), localIvs.end());

    // When an `interchangeVector` is present, it has been applied to the
    // loop ranges and the iterator types. Apply its inverse to the
    // resulting loop `ivs` to match the op definition.
    SmallVector<Value, 4> interchangedIvs;
    if (!options.interchangeVector.empty()) {
      for (AffineExpr result : invPermutationMap.getResults())
        interchangedIvs.push_back(
            ivs[cast<AffineDimExpr>(result).getPosition()]);
    } else {
      interchangedIvs.assign(ivs.begin(), ivs.end());
    }

    // Tile the `operandValuesToUse` that either match the `op` operands
    // themselves or the tile loop arguments forwarding them.
    assert(operandValuesToUse.size() ==
               static_cast<size_t>(op->getNumOperands()) &&
           "expect the number of operands and inputs and outputs to match");
    SmallVector<Value> valuesToTile = operandValuesToUse;
    SmallVector<OpFoldResult> sizeBounds =
        makeComposedFoldedMultiResultAffineApply(b, loc, shapeSizesToLoopsMap,
                                                 allShapeSizes);
    SmallVector<Value> tiledOperands = makeTiledShapes(
        b, loc, op, valuesToTile, getAsOpFoldResult(interchangedIvs), tileSizes,
        sizeBounds,
        /*omitPartialTileCheck=*/false);

    // Clone the op on the tiled operands and stitch the tensor results back
    // into full-size tensors for the loop to yield.
    SmallVector<Type> resultTensorTypes =
        getTensorOutputTypes(op, tiledOperands);
    res = clone(b, op, resultTensorTypes, tiledOperands);
    tensorResults =
        insertSlicesBack(builder, loc, op, tiledOperands, res->getResults());
    return scf::ValueVector(tensorResults.begin(), tensorResults.end());
  };
  GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes,
                                 tiledLoopBodyBuilder, procInfo);

  // 3. Transform IndexOp results w.r.t. the tiling.
  transformIndexOps(b, res, ivs, loopIndexToRangeIndex);

  // 4. Gather the newly created loops and return them with the new op.
  SmallVector<Operation *, 8> loops;
  loops.reserve(ivs.size());
  for (auto iv : ivs) {
    if (isa<BlockArgument>(iv)) {
      loops.push_back(cast<BlockArgument>(iv).getOwner()->getParentOp());
      assert(loops.back() && "no owner found for induction variable!");
    } else {
      // TODO: Instead of doing this, try to recover the ops used instead of the
      // loop.
      loops.push_back(nullptr);
    }
  }

  // 5. Get the tensor results from the outermost loop if available. Otherwise
  // use the previously captured `tensorResults`.
  Operation *outermostLoop = nullptr;
  for (Operation *loop : loops)
    if ((outermostLoop = loop))
      break;

  return TiledLinalgOp{
      res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
}
587
/// Tiles the single reduction dimension of `op` across the threads of a new
/// `scf.forall`: each thread computes a partial reduction into its own slice
/// of an expanded accumulator tensor, and the partials are merged after the
/// loop via `mergeReductions`. `numThreads` is indexed by loop dimension
/// (zero entries are not distributed); if `tileSizes` is non-empty it must
/// have the same arity and an inner serial tiling is applied per thread.
FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
    RewriterBase &b, PartialReductionOpInterface op,
    ArrayRef<OpFoldResult> numThreads, ArrayRef<OpFoldResult> tileSizes,
    std::optional<ArrayAttr> mapping) {
  Location loc = op.getLoc();
  OpBuilder::InsertionGuard g(b);

  // Ops implementing PartialReductionOpInterface are expected to implement
  // TilingInterface.
  // TODO: proper core mechanism to tie interfaces together.
  auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());

  // Ops implementing PartialReductionOpInterface are not necessarily expected
  // to implement TilingInterface.. This cast is unsafe atm.
  // TODO: proper core mechanism to tie interfaces together.
  // TODO: this function requires a pair of interfaces ..
  auto destinationStyleOp =
      dyn_cast<DestinationStyleOpInterface>(op.getOperation());
  if (!destinationStyleOp)
    return b.notifyMatchFailure(op, "not a destination style op");

  // Actually this only works for Linalg ops atm.
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op.getOperation());
  if (!linalgOp)
    return b.notifyMatchFailure(op, "not a linalg op");

  SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
  if (op->getNumResults() != 1)
    return b.notifyMatchFailure(
        op, "don't support ops with multiple results for now");

  // Exactly one reduction dimension is supported, and it must be covered by
  // `numThreads` so it is actually distributed.
  SmallVector<utils::IteratorType> iterators =
      tilingInterfaceOp.getLoopIteratorTypes();
  SmallVector<unsigned> redDims;
  linalgOp.getReductionDims(redDims);
  if (redDims.size() != 1)
    return b.notifyMatchFailure(
        op, "only support ops with one reduction dimension.");
  if (!tileSizes.empty() && tileSizes.size() != numThreads.size())
    return b.notifyMatchFailure(op, "if tile sizes are present it must have as "
                                    "many elements as number of threads");

  if (redDims.front() >= numThreads.size())
    return b.notifyMatchFailure(
        op, "reduction dimension must be mapped to threads");

  // 1. Create the initial tensor value.
  unsigned reductionDim = redDims.front();
  SetVector<unsigned> reductionDims;
  reductionDims.insert(reductionDim);
  FailureOr<SmallVector<Value>> maybeInitTensors =
      op.generateInitialTensorForPartialReduction(b, loc, numThreads,
                                                  reductionDims);
  if (failed(maybeInitTensors))
    return b.notifyMatchFailure(
        op, "Failed to create inital tensors for partial reduction");
  SmallVector<Value> &initTensors = maybeInitTensors.value();

  // Gather destination tensors.
  SmallVector<Value> dest;
  if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
    return b.notifyMatchFailure(op, "failed to get destination tensors");

  Operation *tiledOp = nullptr;

  // Only loops with a non-zero thread count get an induction variable.
  SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
      numThreads, [](OpFoldResult ofr) { return !isZeroInteger(ofr); });
  SmallVector<Value> materializedNonZeroNumThreads =
      getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads);

  // 2. Create the ForallOp with an empty region.
  scf::ForallOp forallOp = b.create<scf::ForallOp>(
      loc, getAsOpFoldResult(materializedNonZeroNumThreads), initTensors,
      mapping);

  // 3. Calculate the tile offsets and sizes for the subsequent loop that will
  // be nested under `forallOp`.
  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
  calculateTileOffsetsAndSizes(b, loc, forallOp, numThreads, iterationDomain,
                               /*omitTileOffsetBoundsCheck =*/false,
                               /*nominalTileSizes=*/std::nullopt, tiledOffsets,
                               tiledSizes);

  // 4b. Clone the tileable op and update its destination operands to use the
  // output bbArgs of the ForallOp.
  SmallVector<Value> tilingResults;
  ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
  {
    // 4.a. RAII guard, inserting within forallOp, before terminator.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(forallOp.getTerminator());

    // Each thread writes its partial result into the slice of the expanded
    // accumulator addressed by its induction variable along `reductionDim`.
    SmallVector<Value> tiledDpsInitOperands;
    for (Value initOperand : destinationStyleOp.getDpsInits()) {
      auto *it = llvm::find(dest, initOperand);
      assert(it != dest.end() && "dest operand not found in dest");
      unsigned destNum = std::distance(dest.begin(), it);
      SmallVector<OpFoldResult> strides(numThreads.size(), b.getIndexAttr(1));
      SmallVector<OpFoldResult> outOffsets(numThreads.size(),
                                           b.getIndexAttr(0));
      SmallVector<OpFoldResult> sizes = tiledSizes;
      sizes[reductionDim] = b.getIndexAttr(1);
      outOffsets[reductionDim] = forallOp.getInductionVars()[0];
      // TODO: use SubsetExtractOpInterface once it is available.
      tiledDpsInitOperands.push_back(b.create<tensor::ExtractSliceOp>(
          loc, cast<RankedTensorType>(initOperand.getType()),
          destBbArgs[destNum], outOffsets, sizes, strides));
    }

    // 4.b. Clone the op and update init operands.
    // We cannot use a IRMapping here because it can replace
    // different OpOperands with the same value.
    Operation *clonedOp = b.clone(*op.getOperation());
    b.modifyOpInPlace(clonedOp, [&]() {
      for (auto [initOperandPtr, tiledInitValue] : llvm::zip_equal(
               cast<DestinationStyleOpInterface>(clonedOp).getDpsInitsMutable(),
               tiledDpsInitOperands)) {
        initOperandPtr.set(tiledInitValue);
      }
    });

    // 5. Tile the cloned op and delete the clone.
    if (tileSizes.empty()) {
      // No inner serial tiling: one tiled op per thread.
      FailureOr<TilingResult> tilingResult =
          cast<TilingInterface>(clonedOp).getTiledImplementation(
              b, tiledOffsets, tiledSizes);
      if (failed(tilingResult))
        return clonedOp->emitError("Failed to tile op: ");
      if (tilingResult->tiledOps.size() != 1) {
        return clonedOp->emitError("expected a single produced tiled op, got ")
               << tilingResult->tiledOps.size();
      }
      tiledOp = tilingResult->tiledOps.front();
      tilingResults = tilingResult->tiledValues;
    } else {
      // Inner serial tiling requested: tile with scf.for, then map the
      // generated loop onto the forall induction variables.
      LinalgTilingOptions options;
      FailureOr<TiledLinalgOp> maybeTiled = tileLinalgOpImpl<scf::ForOp>(
          b, cast<LinalgOp>(clonedOp), tileSizes, options);
      if (failed(maybeTiled))
        return b.notifyMatchFailure(op, "failed tileLinalgOpImpl");

      SmallVector<Value> ids = forallOp.getInductionVars();
      mapLoopToProcessorIds(cast<scf::ForOp>(maybeTiled->loops.back()), ids,
                            materializedNonZeroNumThreads);
      if (maybeTiled->loops.size() != 1) {
        return clonedOp->emitError("expected a single produced loop");
      }
      tiledOp = maybeTiled->op;
      tilingResults = maybeTiled->loops.front()->getResults();
    }

    b.eraseOp(clonedOp);
  }

  // 6. Insert the partial reductions back into a new tensor.
  for (auto [index, result, bbArg] : llvm::zip(
           llvm::seq<unsigned>(0, dest.size()), tilingResults, destBbArgs)) {
    // 6.a. Partial subset information is inserted just before the terminator.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(forallOp.getTerminator());

    SmallVector<OpFoldResult> resultOffsets, resultSizes;
    if (failed(tilingInterfaceOp.getResultTilePosition(
            b, index, tiledOffsets, tiledSizes, resultOffsets, resultSizes)))
      return op->emitOpError("output offsets couldn't be calculated");
    // Expand the result position by one rank: the reduction dimension is
    // indexed by the thread id with extent 1.
    SmallVector<OpFoldResult> resultOffsetsRank, resultSizesRank;
    int64_t offIdx = 0;
    int64_t sizeIdx = 0;
    for (int64_t i = 0, e = numThreads.size(); i < e; ++i) {
      if (i == reductionDim) {
        resultOffsetsRank.push_back(forallOp.getInductionVars()[0]);
        resultSizesRank.push_back(b.getIndexAttr(1));
        continue;
      }
      resultOffsetsRank.push_back(resultOffsets[offIdx++]);
      resultSizesRank.push_back(resultSizes[sizeIdx++]);
    }
    SmallVector<OpFoldResult> strides(resultSizesRank.size(),
                                      b.getIndexAttr(1));

    // 6.b. Parallel insertions are inserted at the end of the combining
    // terminator.
    b.setInsertionPointToEnd(forallOp.getTerminator().getBody());
    b.create<tensor::ParallelInsertSliceOp>(
        loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides);
  }

  // 7. Merge the partial reductions.
  b.setInsertionPointAfter(forallOp);
  FailureOr<MergeResult> mergeResult =
      op.mergeReductions(b, loc, forallOp->getResults(), reductionDims);
  if (failed(mergeResult)) {
    return failure();
  }
  b.replaceOp(op, mergeResult->replacements);

  // 8. Return.
  ForallReductionTilingResult results;
  results.initialValues = initTensors;
  results.loops = forallOp;
  results.parallelTiledOps.push_back(tiledOp);
  results.mergeOps.append(mergeResult->mergeOps);
  return results;
}
792
793template <typename LoopTy>
794FailureOr<TiledLinalgOp> static tileLinalgOpImpl(
795 RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options) {
796 OpBuilder::InsertionGuard g(b);
797 b.setInsertionPoint(op);
798
799 if (!options.tileSizeComputationFunction)
800 return failure();
801
802 // Enforce the convention that "tiling by zero" skips tiling a particular
803 // dimension. This convention is significantly simpler to handle instead of
804 // adjusting affine maps to account for missing dimensions.
805 auto nLoops = op.getNumLoops();
806 SmallVector<OpFoldResult> tileSizeVector =
807 getAsOpFoldResult(values: options.tileSizeComputationFunction(b, op));
808 if (tileSizeVector.size() < nLoops) {
809 tileSizeVector.append(NumInputs: nLoops - tileSizeVector.size(), Elt: b.getIndexAttr(value: 0));
810 }
811
812 return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
813}
814
815FailureOr<TiledLinalgOp>
816mlir::linalg::tileLinalgOp(RewriterBase &b, LinalgOp op,
817 const LinalgTilingOptions &options) {
818 switch (options.loopType) {
819 case LinalgTilingLoopType::Loops:
820 return tileLinalgOpImpl<scf::ForOp>(b, op, options);
821 case LinalgTilingLoopType::ParallelLoops:
822 return tileLinalgOpImpl<scf::ParallelOp>(b, op, options);
823 default:;
824 }
825 return failure();
826}
827
828namespace {
829/// Helper classes for type list expansion.
830template <typename... OpTypes>
831class CanonicalizationPatternList;
832
833template <>
834class CanonicalizationPatternList<> {
835public:
836 static void insert(RewritePatternSet &patterns) {}
837};
838
839template <typename OpTy, typename... OpTypes>
840class CanonicalizationPatternList<OpTy, OpTypes...> {
841public:
842 static void insert(RewritePatternSet &patterns) {
843 OpTy::getCanonicalizationPatterns(patterns, patterns.getContext());
844 CanonicalizationPatternList<OpTypes...>::insert(patterns);
845 }
846};
847} // namespace
848
849RewritePatternSet
850mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) {
851 RewritePatternSet patterns(ctx);
852 populateLinalgTilingCanonicalizationPatterns(patterns);
853 return patterns;
854}
855
856void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
857 RewritePatternSet &patterns) {
858 auto *ctx = patterns.getContext();
859 affine::AffineApplyOp::getCanonicalizationPatterns(results&: patterns, context: ctx);
860 affine::AffineForOp::getCanonicalizationPatterns(results&: patterns, context: ctx);
861 affine::AffineMinOp::getCanonicalizationPatterns(results&: patterns, context: ctx);
862 affine::AffineMaxOp::getCanonicalizationPatterns(results&: patterns, context: ctx);
863 arith::ConstantIndexOp::getCanonicalizationPatterns(results&: patterns, context: ctx);
864
865 memref::SubViewOp::getCanonicalizationPatterns(results&: patterns, context: ctx);
866 memref::ViewOp::getCanonicalizationPatterns(results&: patterns, context: ctx);
867
868 scf::ForOp::getCanonicalizationPatterns(results&: patterns, context: ctx);
869 scf::ParallelOp::getCanonicalizationPatterns(results&: patterns, context: ctx);
870
871 tensor::CastOp::getCanonicalizationPatterns(results&: patterns, context: ctx);
872 tensor::EmptyOp::getCanonicalizationPatterns(results&: patterns, context: ctx);
873 tensor::ExtractSliceOp::getCanonicalizationPatterns(results&: patterns, context: ctx);
874 tensor::InsertSliceOp::getCanonicalizationPatterns(results&: patterns, context: ctx);
875 tensor::PadOp::getCanonicalizationPatterns(results&: patterns, context: ctx);
876 ctx->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(results&: patterns);
877
878 CanonicalizationPatternList<
879#define GET_OP_LIST
880#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
881 >::insert(patterns);
882}
883

// Source: mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp