1//===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
10
11#include "mlir/Analysis/SliceAnalysis.h"
12#include "mlir/Dialect/Affine/IR/AffineOps.h"
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/Arith/Utils/Utils.h"
15#include "mlir/Dialect/Linalg/IR/Linalg.h"
16#include "mlir/Dialect/Linalg/Utils/Utils.h"
17#include "mlir/Dialect/MemRef/IR/MemRef.h"
18#include "mlir/Dialect/Tensor/IR/Tensor.h"
19#include "mlir/Dialect/Utils/StaticValueUtils.h"
20#include "mlir/Interfaces/TilingInterface.h"
21#include <optional>
22
23using namespace mlir;
24using namespace mlir::linalg;
25
26//===----------------------------------------------------------------------===//
27// Utility methods for implementation of Tiling Interface for Linalg ops
28//===----------------------------------------------------------------------===//
29
30/// Return the SSA values that represent the data point accessed using a given
31/// `indexingMap` for a given point in the iteration space represented by `ivs`.
32static SmallVector<Value> getIndicesForAccess(OpBuilder &b, Location loc,
33 AffineMap indexingMap,
34 ValueRange ivs) {
35 SmallVector<Value> indices;
36 indices.reserve(N: indexingMap.getNumResults());
37 for (auto result : indexingMap.getResults()) {
38 AffineMap m = AffineMap::get(dimCount: indexingMap.getNumDims(),
39 symbolCount: indexingMap.getNumSymbols(), result);
40 Value v = b.create<affine::AffineApplyOp>(loc, m, ivs);
41 indices.push_back(Elt: v);
42 }
43 return indices;
44}
45
46/// Method to inline the payload of a `linalgOp` given the iteration space
47/// point and values for the arguments of the payload.
48static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp,
49 ValueRange ivs, ValueRange argValues) {
50 Block *body = linalgOp.getBlock();
51 IRMapping map;
52 map.map(from: body->getArguments(), to&: argValues);
53 for (auto &op : body->without_terminator()) {
54 if (auto indexOp = dyn_cast<IndexOp>(&op)) {
55 map.map(indexOp.getResult(), ivs[indexOp.getDim()]);
56 continue;
57 }
58 b.clone(op, map);
59 }
60
61 Operation *terminator = body->getTerminator();
62 Location loc = terminator->getLoc();
63 for (const auto &operand : llvm::enumerate(terminator->getOperands())) {
64 Value toStore = map.lookupOrDefault(operand.value());
65 OpOperand *storeInto = linalgOp.getDpsInitOperand(operand.index());
66 auto indices = getIndicesForAccess(
67 b, loc, linalgOp.getMatchingIndexingMap(storeInto), ivs);
68 b.create<memref::StoreOp>(
69 loc, toStore, linalgOp.getDpsInitOperand(operand.index())->get(),
70 indices);
71 }
72 return success();
73}
74
75//===----------------------------------------------------------------------===//
76// External Model for implementing `TilingInterface` for `LinalgOp`s.
77//===----------------------------------------------------------------------===//
78
79namespace {
80/// External model implementation of TilingInterface for LinalgOps. An external
81/// model implementation is used for now till the use of `TilingInterface` is
82/// on-par with the current Linalg tiling + fusion patterns. Once it is
83/// maybe possible to move this into the op-definition (though there are
84/// advantages to leaving it as an external model)
85template <typename LinalgOpTy>
86struct LinalgOpTilingInterface
87 : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
88 LinalgOpTy> {
89 /// Return the loop iterator type.
90 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
91 LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
92 return concreteOp.getIteratorTypesArray();
93 }
94
95 /// Return the iteration domain range.
96 SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
97 OpBuilder::InsertionGuard g(b);
98 b.setInsertionPoint(op);
99 Location loc = op->getLoc();
100 LinalgOp linalgOp = cast<LinalgOp>(op);
101 SmallVector<OpFoldResult> allShapesSizes =
102 linalgOp.createFlatListOfOperandDims(b, loc);
103 AffineMap map = linalgOp.getShapesToLoopsMap();
104
105 return llvm::to_vector(
106 llvm::map_range(map.getResults(), [&](AffineExpr loopExpr) {
107 OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
108 b, loc, expr: loopExpr, operands: allShapesSizes);
109 return Range{b.getIndexAttr(0), .size: ofr, b.getIndexAttr(1)};
110 }));
111 }
112
113 // Instantiate the tiled implementation of the operation.
114 FailureOr<TilingResult>
115 getTiledImplementation(Operation *op, OpBuilder &b,
116 ArrayRef<OpFoldResult> offsets,
117 ArrayRef<OpFoldResult> sizes) const {
118 // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
119 // specified could lead to out of bounds accesses.
120 Location loc = op->getLoc();
121 LinalgOp linalgOp = cast<LinalgOp>(op);
122 SmallVector<Value> valuesToTile = linalgOp->getOperands();
123 SmallVector<Value, 4> tiledOperands = makeTiledShapes(
124 b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
125
126 SmallVector<Type> resultTensorTypes =
127 getTensorOutputTypes(linalgOp, tiledOperands);
128
129 Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands);
130 offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);
131
132 return TilingResult{.tiledOps: {tiledOp}, .tiledValues: SmallVector<Value>(tiledOp->getResults())};
133 }
134
135 // Return the details of the output tile generated by the tiled
136 // implementation.
137 LogicalResult
138 getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
139 ArrayRef<OpFoldResult> offsets,
140 ArrayRef<OpFoldResult> sizes,
141 SmallVector<OpFoldResult> &resultOffsets,
142 SmallVector<OpFoldResult> &resultSizes) const {
143 Location loc = op->getLoc();
144 LinalgOp linalgOp = cast<LinalgOp>(op);
145
146 AffineExpr d0;
147 bindDims(ctx: b.getContext(), exprs&: d0);
148 SmallVector<OpFoldResult> subShapeSizes =
149 llvm::to_vector(llvm::map_range(sizes, [&](OpFoldResult ofr) {
150 return affine::makeComposedFoldedAffineApply(b, loc, expr: d0 - 1, operands: ofr);
151 }));
152
153 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
154 SliceParameters sliceParams = computeSliceParameters(
155 b, loc, outOperand->get(), sizes,
156 linalgOp.getMatchingIndexingMap(outOperand), offsets,
157 /*ubs*/ {}, subShapeSizes, true);
158 resultOffsets = sliceParams.offsets;
159 resultSizes = sliceParams.sizes;
160 return success();
161 }
162
163 FailureOr<TilingResult>
164 generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
165 ArrayRef<OpFoldResult> offsets,
166 ArrayRef<OpFoldResult> sizes) const {
167 auto linalgOp = cast<LinalgOp>(op);
168
169 // Check that the indexing map used for the output is a projected
170 // permutation. This could be relaxed with a more general approach that can
171 // map the offsets and sizes from the result to iteration space tiles
172 // (filling in full extent for dimensions not used to access the result).
173 AffineMap indexingMap =
174 linalgOp.getIndexingMapMatchingResult(op->getResult(idx: resultNumber));
175 if (!indexingMap.isProjectedPermutation()) {
176 return op->emitOpError(
177 message: "unhandled tiled implementation generation when result is not "
178 "accessed using a permuted projection");
179 }
180
181 auto numLoops = linalgOp.getNumLoops();
182 auto tilingInterfaceOp = cast<TilingInterface>(op);
183 SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
184 iterationTileSizes(numLoops);
185 if (!indexingMap.isPermutation()) {
186 SmallVector<Range> iterationDomain =
187 tilingInterfaceOp.getIterationDomain(b);
188 for (const auto &range : llvm::enumerate(iterationDomain)) {
189 iterationTileOffsets[range.index()] = range.value().offset;
190 iterationTileSizes[range.index()] = range.value().size;
191 }
192 }
193 for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
194 unsigned dimPosition =
195 cast<AffineDimExpr>(resultExpr.value()).getPosition();
196 iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
197 iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
198 }
199
200 FailureOr<TilingResult> tilingResult =
201 tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
202 iterationTileSizes);
203 if (tilingResult->tiledOps.size() != 1)
204 return op->emitOpError(message: "failed to generate tiled implementation");
205
206 return TilingResult{
207 .tiledOps: tilingResult->tiledOps,
208 .tiledValues: SmallVector<Value>{tilingResult->tiledValues[resultNumber]}};
209 }
210
211 LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
212 Location loc,
213 ValueRange ivs) const {
214 auto linalgOp = cast<LinalgOp>(op);
215 if (!linalgOp.hasPureBufferSemantics())
216 return op->emitOpError(message: "expected operation to have buffer semantics");
217
218 SmallVector<Value> indexedValues;
219 indexedValues.reserve(N: linalgOp->getNumOperands());
220 Location linalgOpLoc = op->getLoc();
221 /// Load the data corresponding to the block arguments that
222 /// represent input operands.
223 for (OpOperand &operand : linalgOp->getOpOperands()) {
224 if (!linalgOp.payloadUsesValueFromOperand(&operand)) {
225 indexedValues.push_back(nullptr);
226 continue;
227 }
228 if (linalgOp.isScalar(&operand)) {
229 indexedValues.push_back(operand.get());
230 continue;
231 }
232 SmallVector<Value> indices = getIndicesForAccess(
233 builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(&operand), ivs);
234 Value load =
235 builder.create<memref::LoadOp>(linalgOpLoc, operand.get(), indices);
236 indexedValues.push_back(load);
237 }
238
239 /// Inline the op payload and store the result.
240 return inlinePayload(builder, linalgOp, ivs, indexedValues);
241 }
242};
243
244//===----------------------------------------------------------------------===//
// External Model for implementing `PartialReductionOpInterface` for `LinalgOp`s.
246//===----------------------------------------------------------------------===//
247
248/// External model implementation of PartialReductionInterface for LinalgOps.
249template <typename LinalgOpTy>
250struct LinalgOpPartialReductionInterface
251 : public PartialReductionOpInterface::ExternalModel<
252 LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> {
253 FailureOr<Operation *> generateInitialTensorForPartialReduction(
254 Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
255 ArrayRef<int> reductionDims) const {
256 auto linalgOp = cast<LinalgOp>(op);
257 OpBuilder::InsertionGuard guard(b);
258
259 if (linalgOp.hasPureBufferSemantics())
260 return op->emitOpError(message: "expected operation to have tensor semantics");
261 // Insert the new parallel dimension based on the index of the reduction
262 // loops. This could be controlled by user for more flexibility.
263
264 SmallVector<Operation *, 4> combinerOps;
265 if (!matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps) ||
266 combinerOps.size() != 1)
267 return op->emitOpError(message: "Failed to anaysis the reduction operation.");
268
269 Operation *reductionOp = combinerOps[0];
270 std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
271 if (!identity.has_value())
272 return op->emitOpError(
273 message: "Failed to get an identity value for the reduction operation.");
274
275 ArrayRef<int64_t> oldShape =
276 linalgOp.getShape(linalgOp.getDpsInitOperand(0));
277
278 // Calculate the new shape, we insert the new dimensions based on the index
279 // of the reduction dimensions.
280 SmallVector<int64_t> newOutputShape;
281 SmallVector<Value> dynamicDims;
282 int64_t currReductionDims = 0;
283 DenseSet<int> reductionDimsSet(reductionDims.begin(), reductionDims.end());
284 for (int64_t idx :
285 llvm::seq<int64_t>(0, oldShape.size() + reductionDims.size())) {
286 if (reductionDimsSet.contains(idx)) {
287 dispatchIndexOpFoldResults(sizes[idx], dynamicDims, newOutputShape);
288 currReductionDims++;
289 continue;
290 }
291 int64_t oldIdx = idx - currReductionDims;
292 int64_t dim = oldShape[oldIdx];
293 newOutputShape.push_back(dim);
294 if (ShapedType::isDynamic(dim))
295 dynamicDims.push_back(b.create<tensor::DimOp>(
296 loc, linalgOp.getDpsInitOperand(0)->get(), oldIdx));
297 }
298 Value emptyTensor = b.create<tensor::EmptyOp>(
299 loc, newOutputShape, linalgOp.getRegionOutputArgs()[0].getType(),
300 dynamicDims);
301 Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
302 auto identityTensor =
303 b.create<linalg::FillOp>(loc, constantOp, emptyTensor);
304 return identityTensor.getOperation();
305 }
306
307 Operation *tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
308 ValueRange init,
309 ArrayRef<OpFoldResult> offsets,
310 ArrayRef<OpFoldResult> sizes,
311 ArrayRef<int> reductionDims) const {
312 OpBuilder::InsertionGuard guard(b);
313 auto linalgOp = cast<LinalgOp>(op);
314
315 AffineMap oldOutputMap =
316 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
317 SmallVector<AffineExpr> outputExpr(oldOutputMap.getNumResults() +
318 reductionDims.size());
319
320 for (int idx : reductionDims)
321 outputExpr[idx] = b.getAffineDimExpr(position: idx);
322 int currExpr = 0;
323 for (int idx : llvm::seq<int>(0, outputExpr.size())) {
324 if (outputExpr[idx])
325 continue;
326 outputExpr[idx] = oldOutputMap.getResult(currExpr++);
327 }
328
329 // Step 1: Extract a slice of the input operands.
330 SmallVector<Value> valuesToTile = linalgOp.getDpsInputs();
331 SmallVector<Value, 4> tiledOperands = makeTiledShapes(
332 b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
333
334 // Step 2: Extract the accumulator operands
335 SmallVector<OpFoldResult> strides(offsets.size(), b.getIndexAttr(1));
336 SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
337 // TODO: use SubsetExtractOpInterface once it is available.
338 Value out = b.create<tensor::ExtractSliceOp>(loc, init[0], outOffsets,
339 sizes, strides);
340
341 // Step3. Create a generic op where the reduction dimensions are replaced
342 // by a parallel dimension of the size of reduction.
343 SmallVector<utils::IteratorType> newIteratorTypes =
344 linalgOp.getIteratorTypesArray();
345 for (int dim : reductionDims)
346 newIteratorTypes[dim] = utils::IteratorType::parallel;
347 SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
348 newMaps.back() = AffineMap::get(newMaps.back().getNumDims(), 0, outputExpr,
349 linalgOp.getContext());
350 auto genericOp =
351 b.create<GenericOp>(loc, TypeRange({out.getType()}), tiledOperands,
352 ValueRange({out}), newMaps, newIteratorTypes);
353 IRMapping mapping;
354 op->getRegion(index: 0).cloneInto(&genericOp.getRegion(),
355 genericOp.getRegion().begin(), mapping);
356 return genericOp.getOperation();
357 }
358
359 Operation *mergeReductions(Operation *op, OpBuilder &b, Location loc,
360 ValueRange partialReduce,
361 ArrayRef<int> reductionDims) const {
362 auto linalgOp = cast<LinalgOp>(op);
363
364 DenseSet<int> reductionDimsSet(reductionDims.begin(), reductionDims.end());
365
366 // Then create a new reduction that only reduce the newly added dimensions
367 // from the previous op.
368 int64_t intermRank = cast<ShapedType>(partialReduce[0].getType()).getRank();
369 AffineMap inputMap = b.getMultiDimIdentityMap(rank: intermRank);
370 SmallVector<utils::IteratorType> reductionIteratorTypes;
371 SmallVector<AffineExpr> exprs;
372
373 for (int64_t i : llvm::seq<int64_t>(Begin: 0, End: intermRank)) {
374 if (reductionDimsSet.contains(V: i)) {
375 reductionIteratorTypes.push_back(utils::IteratorType::reduction);
376 } else {
377 exprs.push_back(Elt: b.getAffineDimExpr(position: i));
378 reductionIteratorTypes.push_back(utils::IteratorType::parallel);
379 }
380 }
381
382 AffineMap outputMap =
383 AffineMap::get(dimCount: intermRank, symbolCount: 0, results: exprs, context: op->getContext());
384 SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};
385
386 SmallVector<Operation *, 4> combinerOps;
387 matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps);
388 Operation *reductionOp = combinerOps[0];
389
390 auto reduction = b.create<GenericOp>(
391 loc, op->getResultTypes(), ValueRange({partialReduce[0]}),
392 linalgOp.getDpsInits(), reductionMaps, reductionIteratorTypes,
393 [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
394 Operation *clonedReductionOp = b.clone(*reductionOp);
395 clonedReductionOp->setOperand(0, inputs[0]);
396 clonedReductionOp->setOperand(1, inputs[1]);
397 b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
398 });
399 return reduction.getOperation();
400 }
401};
402
403} // namespace
404
405template <typename OpType>
406static void registerOne(MLIRContext *ctx) {
407 OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx);
408 OpType::template attachInterface<LinalgOpPartialReductionInterface<OpType>>(
409 *ctx);
410}
411
412/// Variadic helper function.
413template <typename... OpTypes>
414static void registerAll(MLIRContext *ctx) {
415 (registerOne<OpTypes>(ctx), ...);
416}
417
418#define GET_OP_LIST
419
420void mlir::linalg::registerTilingInterfaceExternalModels(
421 DialectRegistry &registry) {
422 registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
423 registerOne<linalg::GenericOp>(ctx);
424 registerAll<
425#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
426 >(ctx);
427 });
428}
429

source code of mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp