//===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

//===----------------------------------------------------------------------===//
// Utility methods for implementation of Tiling Interface for Linalg ops
//===----------------------------------------------------------------------===//

/// Return the SSA values that represent the data point accessed using a given
/// `indexingMap` for a given point in the iteration space represented by `ivs`.
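/// For example (an illustrative sketch, not from the original source), with
///   indexingMap = affine_map<(d0, d1) -> (d1, d0)>
/// and ivs = [%i, %j], the returned indices are [%j, %i], each materialized
/// through an affine.apply op over the corresponding result expression.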
static SmallVector<Value> getIndicesForAccess(OpBuilder &b, Location loc,
                                              AffineMap indexingMap,
                                              ValueRange ivs) {
  SmallVector<Value> indices;
  indices.reserve(indexingMap.getNumResults());
  for (auto result : indexingMap.getResults()) {
    AffineMap m = AffineMap::get(indexingMap.getNumDims(),
                                 indexingMap.getNumSymbols(), result);
    Value v = b.create<affine::AffineApplyOp>(loc, m, ivs);
    indices.push_back(v);
  }
  return indices;
}

/// Method to inline the payload of a `linalgOp` given the iteration space
/// point and values for the arguments of the payload.
static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp,
                                   ValueRange ivs, ValueRange argValues) {
  Block *body = linalgOp.getBlock();
  IRMapping map;
  map.map(body->getArguments(), argValues);
  for (auto &op : body->without_terminator()) {
    if (auto indexOp = dyn_cast<IndexOp>(&op)) {
      map.map(indexOp.getResult(), ivs[indexOp.getDim()]);
      continue;
    }
    b.clone(op, map);
  }

  Operation *terminator = body->getTerminator();
  Location loc = terminator->getLoc();
  for (const auto &operand : llvm::enumerate(terminator->getOperands())) {
    Value toStore = map.lookupOrDefault(operand.value());
    OpOperand *storeInto = linalgOp.getDpsInitOperand(operand.index());
    auto indices = getIndicesForAccess(
        b, loc, linalgOp.getMatchingIndexingMap(storeInto), ivs);
    b.create<memref::StoreOp>(
        loc, toStore, linalgOp.getDpsInitOperand(operand.index())->get(),
        indices);
  }
  return success();
}

//===----------------------------------------------------------------------===//
// External Model for implementing `TilingInterface` for `LinalgOp`s.
//===----------------------------------------------------------------------===//

namespace {
/// External model implementation of TilingInterface for LinalgOps. An external
/// model implementation is used for now till the use of `TilingInterface` is
/// on-par with the current Linalg tiling + fusion patterns. Once it is, it may
/// be possible to move this into the op definition (though there are
/// advantages to leaving it as an external model).
template <typename LinalgOpTy>
struct LinalgOpTilingInterface
    : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
                                            LinalgOpTy> {
  /// Return the loop iterator type.
  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
    return concreteOp.getIteratorTypesArray();
  }

  /// Return the iteration domain range.
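  /// An illustrative example (not from the original source): for a
  /// `linalg.matmul` with shapes MxK, KxN and init MxN, the returned domain is
  /// {[0, M, 1], [0, N, 1], [0, K, 1]}, with each upper bound computed from
  /// the operand dimensions through the shapes-to-loops map.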
  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(op);
    Location loc = op->getLoc();
    LinalgOp linalgOp = cast<LinalgOp>(op);
    SmallVector<OpFoldResult> allShapesSizes =
        linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();

    return llvm::to_vector(
        llvm::map_range(map.getResults(), [&](AffineExpr loopExpr) {
          OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
              b, loc, loopExpr, allShapesSizes);
          return Range{b.getIndexAttr(0), ofr, b.getIndexAttr(1)};
        }));
  }

  /// Instantiate the tiled implementation of the operation.
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
    // specified could lead to out of bounds accesses.
    Location loc = op->getLoc();
    LinalgOp linalgOp = cast<LinalgOp>(op);
    SmallVector<Value> valuesToTile = linalgOp->getOperands();
    SmallVector<Value> tiledOperands = makeTiledShapes(
        b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
    SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
        llvm::make_filter_range(
            tiledOperands,
            [](Value v) -> bool {
              return isa_and_nonnull<tensor::ExtractSliceOp, memref::SubViewOp>(
                  v.getDefiningOp());
            }),
        [](Value v) -> Operation * { return v.getDefiningOp(); });

    SmallVector<Type> resultTensorTypes =
        getTensorOutputTypes(linalgOp, tiledOperands);

    Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands);
    offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);

    return TilingResult{
        {tiledOp}, SmallVector<Value>(tiledOp->getResults()), generatedSlices};
  }

  /// Utility to fetch the offsets and sizes when applied as per the indexing
  /// map of the linalg op. This helps in fusing the linalg op as a consumer of
  /// a given slice op.
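  /// A hypothetical illustration (not from the original source): with
  ///   indexingMap = affine_map<(d0, d1, d2) -> (d2, d0)>
  /// and (offsets, sizes) = ([%o0, %o1], [%s0, %s1]) from the slice, the
  /// mapped values are mappedOffsets = [%o1, *, %o0] and
  /// mappedSizes = [%s1, *, %s0], where the entries marked `*` (dimension d1,
  /// which the map does not access) are filled from the op's iteration domain.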
  void
  getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
                         SmallVectorImpl<OpFoldResult> &mappedOffsets,
                         SmallVectorImpl<OpFoldResult> &mappedSizes) const {
    unsigned numLoops = linalgOp.getNumLoops();
    auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
    mappedOffsets.resize(numLoops);
    mappedSizes.resize(numLoops);
    if (!indexingMap.isPermutation()) {
      SmallVector<Range> iterationDomain =
          tilingInterfaceOp.getIterationDomain(b);
      for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
        mappedOffsets[index] = value.offset;
        mappedSizes[index] = value.size;
      }
    }
    for (const auto &&[index, value] :
         llvm::enumerate(indexingMap.getResults())) {
      unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
      mappedOffsets[dimPosition] = offsets[index];
      mappedSizes[dimPosition] = sizes[index];
    }
  }

  /// Method to return the position of the iteration domain tile computed from
  /// the given operand tile.
  LogicalResult getIterationDomainTileFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
    auto linalgOp = cast<LinalgOp>(op);

    // Check that the indexing map used for the operand is a projected
    // permutation. This could be relaxed with a more general approach that can
    // map the offsets and sizes from the operand to iteration space tiles
    // (filling in full extent for dimensions not used to access the result).
    AffineMap indexingMap =
        linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
    if (!indexingMap.isProjectedPermutation()) {
      return op->emitError()
             << "unhandled get iter domain position when operand is not "
                "accessed using a permuted projection";
    }

    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
                           iterDomainOffsets, iterDomainSizes);
    return success();
  }

  /// Return the details of the output tile generated by the tiled
  /// implementation.
  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    Location loc = op->getLoc();
    LinalgOp linalgOp = cast<LinalgOp>(op);

    AffineExpr d0;
    bindDims(b.getContext(), d0);
    SmallVector<OpFoldResult> subShapeSizes =
        llvm::to_vector(llvm::map_range(sizes, [&](OpFoldResult ofr) {
          return affine::makeComposedFoldedAffineApply(b, loc, d0 - 1, ofr);
        }));

    OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
    SliceParameters sliceParams = computeSliceParameters(
        b, loc, outOperand->get(), sizes,
        linalgOp.getMatchingIndexingMap(outOperand), offsets,
        /*ubs*/ {}, subShapeSizes, true);
    resultOffsets = sliceParams.offsets;
    resultSizes = sliceParams.sizes;
    return success();
  }

  LogicalResult getIterationDomainTileFromResultTile(
      Operation *op, OpBuilder &b, unsigned resultNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
    auto linalgOp = cast<LinalgOp>(op);

    // Check that the indexing map used for the output is a projected
    // permutation. This could be relaxed with a more general approach that can
    // map the offsets and sizes from the result to iteration space tiles
    // (filling in full extent for dimensions not used to access the result).
    AffineMap indexingMap =
        linalgOp.getIndexingMapMatchingResult(op->getResult(resultNumber));
    if (!indexingMap.isProjectedPermutation()) {
      return op->emitOpError(
          "unhandled tiled implementation generation when result is not "
          "accessed using a permuted projection");
    }

    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
                           iterDomainOffsets, iterDomainSizes);
    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
    if (failed(getIterationDomainTileFromResultTile(
            op, b, resultNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
      return failure();
    }
    auto tilingInterfaceOp = cast<TilingInterface>(op);
    FailureOr<TilingResult> tilingResult =
        tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);

    if (failed(tilingResult))
      return failure();

    if (tilingResult->tiledOps.size() != 1)
      return op->emitOpError("failed to generate tiled implementation");

    return TilingResult{
        tilingResult->tiledOps,
        SmallVector<Value>{tilingResult->tiledValues[resultNumber]},
        tilingResult->generatedSlices};
  }

  /// Method to generate the tiled implementation of an operation from the tile
  /// of the operand.
  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
    if (failed(getIterationDomainTileFromOperandTile(
            op, b, operandNumber, offsets, sizes, mappedOffsets,
            mappedSizes))) {
      return failure();
    }
    return getTiledImplementation(op, b, mappedOffsets, mappedSizes);
  }

  LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
                                             Location loc,
                                             ValueRange ivs) const {
    auto linalgOp = cast<LinalgOp>(op);
    if (!linalgOp.hasPureBufferSemantics())
      return op->emitOpError("expected operation to have buffer semantics");

    SmallVector<Value> indexedValues;
    indexedValues.reserve(linalgOp->getNumOperands());
    Location linalgOpLoc = op->getLoc();
    /// Load the data corresponding to the block arguments that
    /// represent input operands.
    for (OpOperand &operand : linalgOp->getOpOperands()) {
      if (!linalgOp.payloadUsesValueFromOperand(&operand)) {
        indexedValues.push_back(nullptr);
        continue;
      }
      if (linalgOp.isScalar(&operand)) {
        indexedValues.push_back(operand.get());
        continue;
      }
      SmallVector<Value> indices = getIndicesForAccess(
          builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(&operand), ivs);
      Value load =
          builder.create<memref::LoadOp>(linalgOpLoc, operand.get(), indices);
      indexedValues.push_back(load);
    }

    /// Inline the op payload and store the result.
    return inlinePayload(builder, linalgOp, ivs, indexedValues);
  }
};

//===----------------------------------------------------------------------===//
// External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
//===----------------------------------------------------------------------===//

/// Return an AffineMap for a partial result for the given result number,
/// assuming the partial tiling strategy is outer-reduction loop +
/// inner-parallel tile. The returned AffineMap can be used as the replacement
/// AffineMap for the inner-parallel tile linalg op for the given result number.
///
/// The new AffineMap is the old AffineMap with the reduction dimensions
/// appended at the end.
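///
/// For example (an illustrative sketch, not from the original source), for an
/// init whose indexing map is affine_map<(d0, d1, d2) -> (d0, d1)> and
/// reductionDims = {2}, the partial result map is
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>.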
static AffineMap getPartialResultAffineMap(LinalgOp linalgOp,
                                           ArrayRef<int> reductionDims,
                                           unsigned resultNumber) {
  AffineMap map =
      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(resultNumber));
  for (int redPos : reductionDims) {
    map = map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
                           map.getNumResults());
  }
  return map;
}

/// External model implementation of PartialReductionInterface for
/// LinalgOps.
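///
/// At a high level (an illustrative sketch, not from the original source),
/// splitting a reduction through this interface proceeds in three steps:
///   1. `generateInitialTensorForPartialReduction` creates a tensor shaped by
///      the partial result map over the tiled iteration space and fills it
///      with the reduction's neutral element.
///   2. `tileToPartialReduction` emits a tiled `linalg.generic` whose
///      reduction dimensions are rewritten as parallel and whose init maps
///      gain those dimensions, accumulating per-tile partial results.
///   3. `mergeReductions` emits a `linalg.reduce` over the partial result that
///      folds the extra dimensions back using the original combiner.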
template <typename LinalgOpTy>
struct LinalgOpPartialReductionInterface
    : public PartialReductionOpInterface::ExternalModel<
          LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> {
  FailureOr<SmallVector<Value>> generateInitialTensorForPartialReduction(
      Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
      ArrayRef<int> reductionDims) const {
    auto linalgOp = cast<LinalgOp>(op);
    OpBuilder::InsertionGuard guard(b);

    if (linalgOp.hasPureBufferSemantics())
      return op->emitOpError("expected operation to have tensor semantics");

    // LinalgOp implements TilingInterface.
    auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
    SmallVector<OpFoldResult> shape =
        llvm::map_to_vector(tilingInterfaceOp.getIterationDomain(b),
                            [](Range x) { return x.size; });

    SmallVector<OpFoldResult> tiledShape;
    for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) {
      if (isZeroInteger(tileSize)) {
        tiledShape.push_back(dimSize);
      } else {
        tiledShape.push_back(tileSize);
      }
    }

    SmallVector<Value> inits;
    for (int initIdx = 0, e = linalgOp.getNumDpsInits(); initIdx < e;
         ++initIdx) {
      SmallVector<Operation *, 4> combinerOps;
      if (!matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
                          combinerOps) ||
          combinerOps.size() != 1)
        return op->emitOpError("Failed to analyze the reduction operation.");

      Operation *reductionOp = combinerOps[0];
      std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
      if (!identity.has_value())
        return op->emitOpError(
            "Failed to get an identity value for the reduction operation.");

      // Append the new partial result dimensions.
      AffineMap partialMap =
          getPartialResultAffineMap(linalgOp, reductionDims, initIdx);
      SmallVector<OpFoldResult> partialResultShape;
      for (AffineExpr dimExpr : partialMap.getResults()) {
        auto dim = cast<AffineDimExpr>(dimExpr);
        partialResultShape.push_back(tiledShape[dim.getPosition()]);
      }

      Type elType =
          getElementTypeOrSelf(linalgOp->getResult(initIdx).getType());
      Value emptyTensor =
          b.create<tensor::EmptyOp>(loc, partialResultShape, elType);
      Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
      auto identityTensor =
          b.create<linalg::FillOp>(loc, constantOp, emptyTensor);
      inits.push_back(identityTensor.getResult(0));
    }

    return inits;
  }

  FailureOr<TilingResult>
  tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
                         ValueRange init, ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
                         ArrayRef<int> reductionDims) const {
    OpBuilder::InsertionGuard guard(b);
    auto linalgOp = cast<LinalgOp>(op);

    // Step 1. Extend the init maps to include the reduction dimensions, since
    // we are converting them to parallel dimensions.
    SmallVector<AffineMap> newInitMaps;
    newInitMaps.reserve(linalgOp.getNumDpsInits());
    for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
      // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
      // this with a for range loop when we have it.
      AffineMap newMap =
          getPartialResultAffineMap(linalgOp, reductionDims, idx);
      newInitMaps.push_back(newMap);
    }

    // Step 2a: Extract a slice of the input operands.
    SmallVector<Value> tiledInputs = makeTiledShapes(
        b, loc, linalgOp, linalgOp.getDpsInputs(), offsets, sizes, {}, true);
    SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
        llvm::make_filter_range(
            tiledInputs, [](Value v) -> bool { return v.getDefiningOp(); }),
        [](Value v) -> Operation * { return v.getDefiningOp(); });

    // Step 2b: Extract a slice of the init operands.
    SmallVector<Value, 1> tiledInits;
    for (auto [valueMap, valueToTile] : llvm::zip_equal(newInitMaps, init)) {
      int64_t initRank = valueMap.getNumResults();
      SmallVector<OpFoldResult> initOffset(initRank, b.getIndexAttr(0));
      SmallVector<OpFoldResult> initStride(initRank, b.getIndexAttr(1));
      SmallVector<OpFoldResult> initSizes;
      for (AffineExpr dimExpr : valueMap.getResults()) {
        auto dim = cast<AffineDimExpr>(dimExpr);
        initSizes.push_back(sizes[dim.getPosition()]);
      }
      // TODO: Use SubsetExtractOpInterface here once available.
      auto extractSlice = b.create<tensor::ExtractSliceOp>(
          loc, valueToTile, initOffset, initSizes, initStride);
      tiledInits.push_back(extractSlice);
      generatedSlices.push_back(extractSlice);
    }

    // Update the indexing maps.
    SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
    // Change the init maps.
    for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
      // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
      // this with a for range loop when we have it.
      OpOperand *initOperand = linalgOp.getDpsInitOperand(idx);
      int64_t mapIdx = linalgOp.getIndexingMapIndex(initOperand);
      newMaps[mapIdx] = newInitMaps[idx];
    }

    // Step 3. Change the reduction dim iterator types.
    SmallVector<utils::IteratorType> newIteratorTypes =
        linalgOp.getIteratorTypesArray();
    for (int dim : reductionDims)
      newIteratorTypes[dim] = utils::IteratorType::parallel;

    // Step 4. Create the new generic op.
    auto genericOp =
        b.create<GenericOp>(loc, ValueRange(tiledInits).getTypes(), tiledInputs,
                            tiledInits, newMaps, newIteratorTypes);
    IRMapping mapping;
    op->getRegion(0).cloneInto(&genericOp.getRegion(),
                               genericOp.getRegion().begin(), mapping);
    return TilingResult{
        {genericOp.getOperation()},
        llvm::map_to_vector(genericOp->getResults(),
                            [](OpResult r) -> Value { return r; }),
        generatedSlices};
  }

  FailureOr<MergeResult> mergeReductions(Operation *op, OpBuilder &b,
                                         Location loc, ValueRange partialReduce,
                                         ArrayRef<int> reductionDims) const {
    auto linalgOp = cast<LinalgOp>(op);

    // Permute the reduction dims as permuted by the partial result map.

    int64_t numInits = linalgOp.getNumDpsInits();
    SmallVector<Operation *> mergeOperations;
    SmallVector<Value> replacements;
    for (int idx : llvm::seq(numInits)) {
      // linalg.reduce's iteration space is the tiled result's iteration space
      // (and not the tiled operation's iteration space). To account for this,
      // permute the reduction dimensions based on the partial result map of
      // the tiled result.
      AffineMap partialMap =
          getPartialResultAffineMap(linalgOp, reductionDims, idx);
      SmallVector<int64_t> partialReductionDims;
      for (auto [resultNum, dimExpr] :
           llvm::enumerate(partialMap.getResults())) {
        unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
        if (llvm::is_contained(reductionDims, dim)) {
          partialReductionDims.push_back(resultNum);
        }
      }

      Value partialResult = partialReduce[idx];
      Value init = linalgOp.getDpsInits()[idx];

      auto reduction = b.create<linalg::ReduceOp>(
          loc, partialResult, init, partialReductionDims,
          [&linalgOp, &idx](OpBuilder &b, Location loc, ValueRange inputs) {
            // Get the combiner op.
            SmallVector<Operation *, 4> combinerOps;
            matchReduction(linalgOp.getRegionOutputArgs(), idx, combinerOps);
            Operation *clonedReductionOp = b.clone(*combinerOps[0]);
            // Combine the input at idx and output at numInits + idx.
            clonedReductionOp->setOperand(0, inputs[0]);
            clonedReductionOp->setOperand(1, inputs[1]);
            b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
          });

      mergeOperations.push_back(reduction);
      replacements.push_back(reduction->getResult(0));
    }

    return MergeResult{mergeOperations, replacements};
  }

  LogicalResult getPartialResultTilePosition(
      Operation *op, OpBuilder &b, unsigned resultNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVector<OpFoldResult> &resultOffsets,
      SmallVector<OpFoldResult> &resultSizes,
      ArrayRef<int> reductionDims) const {
    auto linalgOp = cast<LinalgOp>(op);

    AffineMap partialMap =
        getPartialResultAffineMap(linalgOp, reductionDims, resultNumber);
    for (AffineExpr dimExpr : partialMap.getResults()) {
      unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
      resultSizes.push_back(sizes[dim]);

      if (llvm::is_contained(reductionDims, dim)) {
        // Reduction dims are reduced, and are always output in the same
        // place. So use offset 0 for them.
        resultOffsets.push_back(b.getIndexAttr(0));
      } else {
        resultOffsets.push_back(offsets[dim]);
      }
    }

    return success();
  }
};

template <typename OpTy>
static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
                                                       OpBuilder &builder) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  OpBuilder::InsertionGuard g(builder);
  int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank()
                                                     : op.getDestRank();
  OpFoldResult zero = builder.getIndexAttr(0);
  OpFoldResult one = builder.getIndexAttr(1);
  ReifiedRankedShapedTypeDims resultShape;
  (void)reifyResultShapes(builder, op, resultShape);
  SmallVector<Range> loopBounds(rank);
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].stride = one;
    loopBounds[dim].size = resultShape[0][dim];
  }
  return loopBounds;
}

static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
                             SmallVector<OpFoldResult> &sizes,
                             ArrayRef<int64_t> permutation) {
  if (permutation.empty())
    return;
  applyPermutationToVector<OpFoldResult>(offsets, permutation);
  applyPermutationToVector<OpFoldResult>(sizes, permutation);
}

struct PackOpTiling
    : public TilingInterface::ExternalModel<PackOpTiling, linalg::PackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    // Note that here we only consider untiled dimensions and outer tiled data
    // dimensions, the inner tiled data dimensions are materialized when
    // building the body of the operation.
    auto packOp = cast<PackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        packOp.getSourceRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b);
  }

  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    Location loc = packOp.getLoc();

    // The tiling is applied on interchanged dimensions. We have to undo the
    // interchange to map sizes and offsets to the original input.
    int64_t inputRank = packOp.getSourceRank();
    SmallVector<OpFoldResult> origOffsets(offsets);
    SmallVector<OpFoldResult> origSizes(sizes);
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(packOp.getOuterDimsPerm()));

    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        packOp.getDimAndTileMapping();
    SmallVector<OpFoldResult> srcDimValues =
        tensor::getMixedSizes(b, loc, packOp.getSource());
    SmallVector<OpFoldResult> inputIndices, inputSizes;
    for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
      using AV = affine::AffineValueExpr;
      affine::AffineBuilder ab(b, loc);
      AffineExpr dim0, dim1, sym;
      bindDims(b.getContext(), dim0, dim1);
      bindSymbols(b.getContext(), sym);
      if (dimAndTileMapping.count(dim)) {
        // If the data dimension is tiled, the i-th index is the product of
        // offset_i and tile_i, and the i-th size is the product of sizes_i and
        // tile_i.
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        inputIndices.push_back(ab.mul(avOffset, avTileSize));
        inputSizes.push_back(ab.mul(avSize, avTileSize));
      } else {
        inputIndices.push_back(origOffsets[dim]);
        inputSizes.push_back(origSizes[dim]);
      }

      // Limit the size of the input operand for incomplete tiles.
      if (packOp.getPaddingValue()) {
        OpFoldResult dimSize = srcDimValues[dim];
        auto avDimSize = AV(dim0).bind(dimSize);
        auto avInputIdx = AV(dim1).bind(inputIndices.back());
        inputSizes.back() =
            ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
      }
    }

    auto oneAttr = b.getI64IntegerAttr(1);
    SmallVector<OpFoldResult> strides(inputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    auto sourceSlice = b.create<tensor::ExtractSliceOp>(
        loc, packOp.getSource(), inputIndices, inputSizes, strides);
    tiledOperands.push_back(sourceSlice);

    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
                                     outputSizes)))
      return {};

    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto outSlice = b.create<tensor::ExtractSliceOp>(
        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(outSlice);

    if (auto val = packOp.getPaddingValue())
      tiledOperands.push_back(val);
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledPackOp = b.create<PackOp>(
        loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());

    return TilingResult{
        {tiledPackOp},
        SmallVector<Value>(tiledPackOp->getResults()),
        llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    // The iteration domain is over the outer dimensions of the packed layout.
    // In this context, the outer dimensions of `resultOffsets` are `offsets`.
    // The inner dimensions of `resultOffsets` are zeros because tiling is not
    // applied to them.
    auto packOp = cast<PackOp>(op);
    int64_t inputRank = packOp.getSourceRank();
    int64_t outputRank = packOp.getDestRank();
    auto zeroAttr = b.getI64IntegerAttr(0);
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultOffsets.append(outputRank - inputRank, zeroAttr);

    ReifiedRankedShapedTypeDims outputShape;
    (void)reifyResultShapes(b, packOp, outputShape);
    resultSizes.assign(sizes.begin(), sizes.end());
    for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
      resultSizes.push_back(outputShape[0][dataTileDim]);

    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    int64_t numTiles = packOp.getInnerDimsPos().size();

    // tensor.pack op is fusible (as a producer) only if full inner tiles are
    // iterated or inner dims are not tiled. Otherwise, it will generate a
    // sequence of non-trivial ops (for partial tiles).
    for (auto offset : offsets.take_back(numTiles))
      if (!isZeroInteger(offset))
        return failure();

    for (auto iter :
         llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles)))
      if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
        return failure();

    FailureOr<TilingResult> tilingResult = getTiledImplementation(
        op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles));
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }

  /// Method to return the position of the iteration domain tile computed by
  /// the tiled operation. In the current `tensor.pack` context, the
  /// `resultOffsets` and `resultSizes` only cover outer dimensions.
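  /// For example (an illustrative sketch, not from the original source), for a
  /// packed dimension with inner tile size 32, an operand tile with offset 64
  /// and size 96 maps to outer offset 64 floordiv 32 = 2 and outer size
  /// ceildiv(96, 32) = 3.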
  LogicalResult getIterationDomainTileFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &resultOffsets,
      SmallVectorImpl<OpFoldResult> &resultSizes) const {
    if (operandNumber != 0)
      return failure();

    auto packOp = cast<PackOp>(op);
    // It is not trivial to infer the dest tile from the source tile if
    // `packOp` has padding semantic.
    if (packOp.getPaddingValue())
      return failure();

    Location loc = packOp.getLoc();

    SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        packOp.getDimAndTileMapping();
    for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
      if (dimAndTileMapping.count(dim)) {
        FailureOr<int64_t> cstSize =
            ValueBoundsConstraintSet::computeConstantBound(
                presburger::BoundType::UB, sizes[dim],
                /*stopCondition=*/nullptr, /*closedUB=*/true);
        std::optional<int64_t> cstInnerSize =
            getConstantIntValue(dimAndTileMapping[dim]);
        // Currently fusing `packOp` as a consumer only expects a perfect
        // tiling scenario because, even without padding semantic, the `packOp`
        // may yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
        // where the `tileSize` from the operand of `packOp` is 5, which is not
        // exactly divided by `innerTile`(=6) of `packOp`. As a result:
        // 1. the first slice is extracted from (0) to (4) and inserted into
        // (0,0)~(0,4) at the first row.
        // 2. the second slice is extracted from (5) to (9) and SHOULD BE
        // respectively inserted into two rows with different lengths,
        // including the first row: (0,5) and the second row (1,0)~(1,3). It is
        // hard to coordinate them, thus the constraint below bypasses them
        // temporarily. In other words, we can only support tiling with a
        // consumer if the tile size for the producer is a multiple of the
        // inner tile size for the packed dimensions at this moment.
        if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) {
          return failure();
        }

        using AV = affine::AffineValueExpr;
        affine::AffineBuilder ab(b, loc);
        AffineExpr dim0, sym;
        bindDims(b.getContext(), dim0);
        bindSymbols(b.getContext(), sym);
        auto avOffset = AV(dim0).bind(offsets[dim]);
        auto avSize = AV(dim0).bind(sizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        outerDimOffsets.push_back(ab.floor(avOffset, avTileSize));
        outerDimSizes.push_back(ab.ceil(avSize, avTileSize));
      } else {
        outerDimOffsets.push_back(offsets[dim]);
        outerDimSizes.push_back(sizes[dim]);
      }
    }
    applyPermToRange(outerDimOffsets, outerDimSizes, packOp.getOuterDimsPerm());
    resultOffsets = outerDimOffsets;
    resultSizes = outerDimSizes;
    return success();
  }

  /// Method to return the tiled implementation of tensor.pack as a consumer.
  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
    if (operandNumber != 0)
      return failure();

    auto packOp = cast<PackOp>(op);
    Location loc = packOp.getLoc();

    int64_t inputRank = packOp.getSourceRank();
    auto oneAttr = b.getI64IntegerAttr(1);
    SmallVector<OpFoldResult> strides(inputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    auto sourceSlice = b.create<tensor::ExtractSliceOp>(
        loc, packOp.getSource(), offsets, sizes, strides);
    tiledOperands.push_back(sourceSlice);

    SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
    if (failed(getIterationDomainTileFromOperandTile(
            op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets,
            outerDimSizes)))
      return failure();

    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getResultTilePosition(op, b, 0, outerDimOffsets, outerDimSizes,
                                     outputOffsets, outputSizes)))
      return failure();

    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto outSlice = b.create<tensor::ExtractSliceOp>(
        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(outSlice);

    assert(!packOp.getPaddingValue() && "Expect no padding semantic");
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledPackOp = b.create<PackOp>(
        loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());

    return TilingResult{
        {tiledPackOp},
        SmallVector<Value>(tiledPackOp->getResults()),
        llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
  }
};

struct UnpackTileDimInfo {
  bool isAlignedToInnerTileSize;
  OpFoldResult sourceOffset;
  OpFoldResult sourceSize;
  OpFoldResult resultOffset;
  OpFoldResult destExpandedSize;
};

/// Returns the needed information for tiling an unpack op on `tileDim` with
/// the given `tileOffset` and `tileSize`. For more details, see the comment on
/// `getTiledImplementation`.
static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
                                              int64_t tileDim,
                                              OpFoldResult tileOffset,
                                              OpFoldResult tileSize) {
  UnpackTileDimInfo info;
  Attribute zeroAttr = b.getIndexAttr(0);
  Attribute oneAttr = b.getIndexAttr(1);
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  // The dimension is not one of the packed data dimensions.
  if (!dimAndTileMapping.count(tileDim)) {
    info.isAlignedToInnerTileSize = true;
    info.sourceOffset = tileOffset;
    info.sourceSize = tileSize;
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;
    return info;
  }

  Location loc = unpackOp.getLoc();
  using AV = affine::AffineValueExpr;
  affine::AffineBuilder ab(b, loc);
  AffineExpr dim0, dim1, sym0;
  bindDims(b.getContext(), dim0, dim1);
  bindSymbols(b.getContext(), sym0);

  OpFoldResult innerTileSize = dimAndTileMapping[tileDim];

  info.isAlignedToInnerTileSize = false;
  FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
      presburger::BoundType::UB, tileSize,
      /*stopCondition=*/nullptr, /*closedUB=*/true);
  std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
  if (!failed(cstSize) && cstInnerSize) {
    if (*cstSize % *cstInnerSize == 0)
      info.isAlignedToInnerTileSize = true;

    // If the tiling size equals the inner tiling size, the outer dims are
    // always 1.
    if (*cstInnerSize == *cstSize) {
      auto lhs = AV(dim0).bind(tileOffset);
      auto rhs = AV(dim1).bind(innerTileSize);
      info.sourceOffset = ab.floor(lhs, rhs);
      info.sourceSize = oneAttr;
      info.resultOffset = zeroAttr;
      info.destExpandedSize = tileSize;
      return info;
    }
  }

  if (info.isAlignedToInnerTileSize) {
    info.sourceOffset =
        ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;

    // The ceilDiv is needed here because there could be incomplete tiles even
    // in perfect tiling cases. E.g.,
    //   %0 = unpack tensor<33x2xf32> into tensor<64xf32>
    // If the tiling size is 32, there will be 3 tiles. Two of them have
    // size=32; one of them has size=2. The size is represented using an
    // affine_min op; we need ceilDiv.
    info.sourceSize =
        ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));
    return info;
  }

  affine::DivModValue firstCoord = affine::getDivMod(
      b, loc, getValueOrCreateConstantIndexOp(b, loc, tileOffset),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  OpFoldResult tileExclusiveBound =
      ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));
  affine::DivModValue lastCoord = affine::getDivMod(
      b, loc,
      getValueOrCreateConstantIndexOp(
          b, loc,
          ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));

  OpFoldResult lengthMinusOne = ab.sub(AV(dim0).bind(lastCoord.quotient),
                                       AV(dim1).bind(firstCoord.quotient));
  info.sourceSize =
      ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));
  info.sourceOffset = firstCoord.quotient;
  info.resultOffset = firstCoord.remainder;
  // Do not create affine ops for the expanded size because the affine op would
  // be too complicated, which would trigger an issue in affine op
  // simplification.
  info.destExpandedSize = b.createOrFold<arith::MulIOp>(
      loc, getValueOrCreateConstantIndexOp(b, loc, info.sourceSize),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  return info;
}

struct UnPackOpTiling
    : public TilingInterface::ExternalModel<UnPackOpTiling, linalg::UnPackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto unpackOp = cast<UnPackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        unpackOp.getDestRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op), b);
  }

  /// There are two cases in tiling unpack ops. If the tiling size is aligned
  /// to the inner tile size, the corresponding tiles of the source are all
  /// complete. Otherwise, there are incomplete tiles. We will need to expand
  /// the slice of the source to get complete tiles. The tiled unpack op
  /// unpacks more data from the source, so we'll need an extract_slice op to
  /// shift and truncate the output.
  /// Take Nn_to_N as an example. Say that N=32, n=8, and tiling_size=15. The
  /// coordinates of the second tile (i.e., result[15..31]) are
  /// [(1, 7), (2, 0), (2, 1) ... (3, 6), (3, 7)]. The first row and the last
  /// row are incomplete tiles. To represent the unpack op, we have to complete
  /// the rows. I.e., the input coordinates would start with (1, 0) and end
  /// with (3, 7). In this context, the tiled unpack produces (3 * n) elements
  /// because there are 3 rows in total. Followed by a tensor.extract_slice op,
  /// we can get the actual result.
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto unpackOp = cast<UnPackOp>(op);
    int64_t srcRank = unpackOp.getSourceRank();
    int64_t destRank = unpackOp.getDestRank();
    int64_t numInnerTiles = srcRank - destRank;
    Location loc = unpackOp.getLoc();

    // The perfect tiling case indicates that the tiling sizes are multiples of
    // the inner_tile_size. In this context, no extra data is needed when
    // representing the tiled unpack op.
    bool isPerfectTilingCase = true;
    Attribute oneAttr = b.getIndexAttr(1);
    SmallVector<OpFoldResult> sliceSrcStrides(destRank, oneAttr);
    SmallVector<OpFoldResult> sliceSrcIndices, sliceSrcSizes;
    SmallVector<OpFoldResult> destExpandedSizes, resultOffsetsFromDest;
    for (auto dim : llvm::seq<int64_t>(0, destRank)) {
      UnpackTileDimInfo info =
          getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]);
      if (!info.isAlignedToInnerTileSize)
        isPerfectTilingCase = false;
      sliceSrcIndices.push_back(info.sourceOffset);
      sliceSrcSizes.push_back(info.sourceSize);
      destExpandedSizes.push_back(info.destExpandedSize);
      resultOffsetsFromDest.push_back(info.resultOffset);
    }

    // The tiling is applied on destination dimensions. We have to apply the
    // interchange on source dimensions if outer_dims_perm is set.
    applyPermToRange(sliceSrcIndices, sliceSrcSizes,
                     unpackOp.getOuterDimsPerm());
    Attribute zeroAttr = b.getIndexAttr(0);
    sliceSrcIndices.append(numInnerTiles, zeroAttr);
    sliceSrcSizes.append(unpackOp.getMixedTiles());
    sliceSrcStrides.append(numInnerTiles, oneAttr);
    SmallVector<Operation *> generatedSlices;
    tensor::ExtractSliceOp sliceSource = b.create<tensor::ExtractSliceOp>(
        loc, unpackOp.getSource(), sliceSrcIndices, sliceSrcSizes,
        sliceSrcStrides);
    generatedSlices.push_back(sliceSource);

    SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
    Value sliceDest;
    if (isPerfectTilingCase) {
      auto destSliceOp = b.create<tensor::ExtractSliceOp>(
          loc, unpackOp.getDest(), offsets, sizes, destStrides);
      sliceDest = destSliceOp;
      generatedSlices.push_back(destSliceOp);
    } else {
      sliceDest = b.create<tensor::EmptyOp>(
          loc, destExpandedSizes, unpackOp.getDestType().getElementType());
    }

    SmallVector<Value> tiledOperands = {sliceSource.getResult(), sliceDest};
    for (auto tile : unpackOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledUnpackOp = b.create<UnPackOp>(
        loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs());

    if (isPerfectTilingCase)
      return TilingResult{{tiledUnpackOp},
                          SmallVector<Value>(tiledUnpackOp->getResults()),
                          generatedSlices};

    auto extractSlice = b.create<tensor::ExtractSliceOp>(
        loc, tiledUnpackOp->getResult(0), resultOffsetsFromDest, sizes,
        destStrides);
    return TilingResult{
        {tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices};
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets = llvm::to_vector(offsets);
    resultSizes = llvm::to_vector(sizes);
    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> tilingResult =
        getTiledImplementation(op, b, offsets, sizes);
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }

  /// Method to return the position of the iteration domain tile computed by
  /// the tiled operation.
  LogicalResult getIterationDomainTileFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &resultOffsets,
      SmallVectorImpl<OpFoldResult> &resultSizes) const {
    auto unPackOp = cast<UnPackOp>(op);
    // If the operand tile is the dest, then no adjustment is needed.
    if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) {
      resultOffsets = llvm::to_vector(offsets);
      resultSizes = llvm::to_vector(sizes);
      return success();
    }
    Location loc = unPackOp.getLoc();

    int64_t numTiles = unPackOp.getInnerDimsPos().size();
    auto destOffsets = offsets.drop_back(numTiles);
    auto destSizes = sizes.drop_back(numTiles);
    // The tiling is applied on interchanged dimensions. We have to undo the
    // interchange to map sizes and offsets to the original input.
    int64_t outputRank = unPackOp.getDestRank();
    ReifiedRankedShapedTypeDims reifiedReturnShapes;
    if (failed(reifyResultShapes(b, unPackOp, reifiedReturnShapes)))
      return failure();
    SmallVector<OpFoldResult> outputMixedSizes = reifiedReturnShapes.front();
    SmallVector<OpFoldResult> origOffsets(destOffsets);
    SmallVector<OpFoldResult> origSizes(destSizes);
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(unPackOp.getOuterDimsPerm()));

    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        unPackOp.getDimAndTileMapping();

    for (auto dim : llvm::seq<int64_t>(0, outputRank)) {
      using AV = affine::AffineValueExpr;
      affine::AffineBuilder ab(b, loc);
      AffineExpr dim0, dim1, sym0;
      bindDims(b.getContext(), dim0, dim1);
      bindSymbols(b.getContext(), sym0);
      if (dimAndTileMapping.count(dim)) {
        // If the data dimension is tiled, the i-th index is the product of
        // offset_i and tile_i, and the i-th size is the product of sizes_i and
        // tile_i. The sizes must be clamped to the sizes of the unpack result.
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym0).bind(dimAndTileMapping[dim]);
        auto avResultSize = AV(dim0).bind(outputMixedSizes[dim]);
        resultOffsets.push_back(ab.mul(avOffset, avTileSize));
        auto avResultOffset = AV(dim1).bind(resultOffsets.back());
        resultSizes.push_back(ab.min({ab.mul(avSize, avTileSize),
                                      ab.sub(avResultSize, avResultOffset)}));
      } else {
        resultOffsets.push_back(origOffsets[dim]);
        resultSizes.push_back(origSizes[dim]);
      }
    }
    return success();
  }

  /// Method to return the tiled implementation of tensor.unpack as a consumer.
  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
    auto unPackOp = cast<UnPackOp>(op);
    // tensor.unpack op is fusible (as a consumer) only if inner dims are not
    // tiled.
    int64_t numTiles = unPackOp.getInnerDimsPos().size();
    for (auto iter :
         llvm::zip_equal(unPackOp.getMixedTiles(), sizes.take_back(numTiles))) {
      if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
        return failure();
    }

    Location loc = unPackOp.getLoc();

    // Fetch offset/size for creating the slice of the dest operand of
    // unpack op.
    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getIterationDomainTileFromOperandTile(
            op, b, /*operandNumber=*/0, offsets, sizes, outputOffsets,
            outputSizes)))
      return failure();

    auto oneAttr = b.getI64IntegerAttr(1);
    int64_t outputRank = unPackOp.getDestRank();
    SmallVector<OpFoldResult> strides(outputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    // Create slice of the dest operand.
    auto extractDestSlice = b.create<tensor::ExtractSliceOp>(
        loc, unPackOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(extractDestSlice);

    strides.append(unPackOp.getSourceRank() - outputRank, oneAttr);
    // Create slice of the source operand.
    auto extractSourceSlice = b.create<tensor::ExtractSliceOp>(
        loc, unPackOp.getSource(), offsets, sizes, strides);
    tiledOperands.insert(tiledOperands.begin(), extractSourceSlice);
    for (auto tile : unPackOp.getInnerTiles())
      tiledOperands.push_back(tile);

    // Create tiled unpack op.
    Operation *tiledUnPackOp =
        b.create<UnPackOp>(loc, TypeRange{extractDestSlice.getType()},
                           tiledOperands, op->getAttrs());

    return TilingResult{{tiledUnPackOp},
                        SmallVector<Value>(tiledUnPackOp->getResults()),
                        llvm::to_vector(ArrayRef<Operation *>{
                            extractSourceSlice, extractDestSlice})};
  }
};

} // namespace

template <typename OpType>
static void registerOne(MLIRContext *ctx) {
  OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx);
  OpType::template attachInterface<LinalgOpPartialReductionInterface<OpType>>(
      *ctx);
}

/// Variadic helper function.
template <typename... OpTypes>
static void registerAll(MLIRContext *ctx) {
  (registerOne<OpTypes>(ctx), ...);
}

#define GET_OP_LIST

void mlir::linalg::registerTilingInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
    registerOne<linalg::GenericOp>(ctx);
    linalg::PackOp::attachInterface<PackOpTiling>(*ctx);
    linalg::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
    registerAll<
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
        >(ctx);
  });
}
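
// A typical usage sketch (illustrative, not part of the original file; the
// exact place a client appends the registry varies):
//
//   DialectRegistry registry;
//   registry.insert<linalg::LinalgDialect>();
//   linalg::registerTilingInterfaceExternalModels(registry);
//   MLIRContext context(registry);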

void mlir::linalg::registerTilingInterfaceExternalModelsForPackUnPackOps(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) {
    linalg::PackOp::attachInterface<PackOpTiling>(*ctx);
    linalg::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}