//===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Interfaces/TilingInterface.h"
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

//===----------------------------------------------------------------------===//
// Utility methods for implementation of Tiling Interface for Linalg ops
//===----------------------------------------------------------------------===//
/// Return the SSA values that represent the data point accessed using a given
/// `indexingMap` for a given point in the iteration space represented by `ivs`.
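/// For example (illustrative), with `indexingMap = (d0, d1) -> (d1, d0)` and
/// `ivs = [%i, %j]`, this materializes two `affine.apply` ops whose results
/// are equivalent to `[%j, %i]`.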
static SmallVector<Value> getIndicesForAccess(OpBuilder &b, Location loc,
                                              AffineMap indexingMap,
                                              ValueRange ivs) {
  SmallVector<Value> indices;
  indices.reserve(indexingMap.getNumResults());
  for (auto result : indexingMap.getResults()) {
    AffineMap m = AffineMap::get(indexingMap.getNumDims(),
                                 indexingMap.getNumSymbols(), result);
    Value v = b.create<affine::AffineApplyOp>(loc, m, ivs);
    indices.push_back(v);
  }
  return indices;
}

/// Method to inline the payload of a `linalgOp` given the iteration space
/// point and values for the arguments of the payload.
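/// `linalg.index` ops in the payload are replaced by the corresponding
/// induction variables, and the terminator operands are stored back into the
/// init operands via `memref.store` at indices derived from each init
/// operand's indexing map. For a pointwise op, the tail of the inlined body
/// looks roughly like (illustrative):
///
///   %sum = arith.addf %lhs_elem, %rhs_elem : f32
///   memref.store %sum, %out[%i, %j] : memref<?x?xf32>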
static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp,
                                   ValueRange ivs, ValueRange argValues) {
  Block *body = linalgOp.getBlock();
  IRMapping map;
  map.map(body->getArguments(), argValues);
  for (auto &op : body->without_terminator()) {
    if (auto indexOp = dyn_cast<IndexOp>(&op)) {
      map.map(indexOp.getResult(), ivs[indexOp.getDim()]);
      continue;
    }
    b.clone(op, map);
  }

  Operation *terminator = body->getTerminator();
  Location loc = terminator->getLoc();
  for (const auto &operand : llvm::enumerate(terminator->getOperands())) {
    Value toStore = map.lookupOrDefault(operand.value());
    OpOperand *storeInto = linalgOp.getDpsInitOperand(operand.index());
    auto indices = getIndicesForAccess(
        b, loc, linalgOp.getMatchingIndexingMap(storeInto), ivs);
    b.create<memref::StoreOp>(
        loc, toStore, linalgOp.getDpsInitOperand(operand.index())->get(),
        indices);
  }
  return success();
}

//===----------------------------------------------------------------------===//
// External Model for implementing `TilingInterface` for `LinalgOp`s.
//===----------------------------------------------------------------------===//

namespace {
/// External model implementation of TilingInterface for LinalgOps. An external
/// model implementation is used for now till the use of `TilingInterface` is
/// on-par with the current Linalg tiling + fusion patterns. Once that is the
/// case, it may be possible to move this into the op definition (though there
/// are advantages to leaving it as an external model).
template <typename LinalgOpTy>
struct LinalgOpTilingInterface
    : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
                                            LinalgOpTy> {
  /// Return the loop iterator type.
  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
    return concreteOp.getIteratorTypesArray();
  }

  /// Return the iteration domain range.
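  /// For example (illustrative), for a `linalg.matmul` with operand shapes
  /// `MxK`, `KxN`, and `MxN`, this returns the ranges `[0, M)`, `[0, N)`, and
  /// `[0, K)`, each with stride 1.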
  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(op);
    Location loc = op->getLoc();
    LinalgOp linalgOp = cast<LinalgOp>(op);
    SmallVector<OpFoldResult> allShapesSizes =
        linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();

    return llvm::to_vector(
        llvm::map_range(map.getResults(), [&](AffineExpr loopExpr) {
          OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
              b, loc, loopExpr, allShapesSizes);
          return Range{b.getIndexAttr(0), ofr, b.getIndexAttr(1)};
        }));
  }

  // Instantiate the tiled implementation of the operation.
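  // For example (illustrative), given iteration-space offsets and sizes, this
  // emits a `tensor.extract_slice` (or `memref.subview`) per operand and
  // clones the op onto those slices:
  //
  //   %lhs_tile = tensor.extract_slice %lhs[...] : tensor<?x?xf32>
  //   %rhs_tile = tensor.extract_slice %rhs[...] : tensor<?x?xf32>
  //   %res_tile = linalg.matmul ins(%lhs_tile, %rhs_tile : ...) ...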
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
    // specified could lead to out-of-bounds accesses.
    Location loc = op->getLoc();
    LinalgOp linalgOp = cast<LinalgOp>(op);
    SmallVector<Value> valuesToTile = linalgOp->getOperands();
    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
        b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);

    SmallVector<Type> resultTensorTypes =
        getTensorOutputTypes(linalgOp, tiledOperands);

    Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands);
    offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);

    return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
  }

  // Return the details of the output tile generated by the tiled
  // implementation.
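  // For example (illustrative), if the output indexing map is
  // (d0, d1, d2) -> (d0, d1), the result tile starts at
  // (offsets[0], offsets[1]) and has sizes (sizes[0], sizes[1]).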
  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    Location loc = op->getLoc();
    LinalgOp linalgOp = cast<LinalgOp>(op);

    AffineExpr d0;
    bindDims(b.getContext(), d0);
    SmallVector<OpFoldResult> subShapeSizes =
        llvm::to_vector(llvm::map_range(sizes, [&](OpFoldResult ofr) {
          return affine::makeComposedFoldedAffineApply(b, loc, d0 - 1, ofr);
        }));

    OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
    SliceParameters sliceParams = computeSliceParameters(
        b, loc, outOperand->get(), sizes,
        linalgOp.getMatchingIndexingMap(outOperand), offsets,
        /*ubs*/ {}, subShapeSizes, true);
    resultOffsets = sliceParams.offsets;
    resultSizes = sliceParams.sizes;
    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    auto linalgOp = cast<LinalgOp>(op);

    // Check that the indexing map used for the output is a projected
    // permutation. This could be relaxed with a more general approach that can
    // map the offsets and sizes from the result to iteration space tiles
    // (filling in full extent for dimensions not used to access the result).
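    // For example (illustrative), (d0, d1, d2) -> (d0, d1) is a projected
    // permutation, while (d0, d1) -> (d0 + d1) is not.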
    AffineMap indexingMap =
        linalgOp.getIndexingMapMatchingResult(op->getResult(resultNumber));
    if (!indexingMap.isProjectedPermutation()) {
      return op->emitOpError(
          "unhandled tiled implementation generation when result is not "
          "accessed using a permuted projection");
    }

    auto numLoops = linalgOp.getNumLoops();
    auto tilingInterfaceOp = cast<TilingInterface>(op);
    SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
        iterationTileSizes(numLoops);
    if (!indexingMap.isPermutation()) {
      SmallVector<Range> iterationDomain =
          tilingInterfaceOp.getIterationDomain(b);
      for (const auto &range : llvm::enumerate(iterationDomain)) {
        iterationTileOffsets[range.index()] = range.value().offset;
        iterationTileSizes[range.index()] = range.value().size;
      }
    }
    for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
      unsigned dimPosition =
          cast<AffineDimExpr>(resultExpr.value()).getPosition();
      iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
      iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
    }

    FailureOr<TilingResult> tilingResult =
        tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
                                                 iterationTileSizes);
    if (tilingResult->tiledOps.size() != 1)
      return op->emitOpError("failed to generate tiled implementation");

    return TilingResult{
        tilingResult->tiledOps,
        SmallVector<Value>{tilingResult->tiledValues[resultNumber]}};
  }

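  /// Generate a scalar implementation of the op at the iteration point `ivs`:
  /// load the accessed elements with `memref.load`, inline the payload, and
  /// store the results with `memref.store`. For a pointwise add this yields
  /// roughly (illustrative):
  ///
  ///   %a = memref.load %lhs[%i] : memref<?xf32>
  ///   %b = memref.load %rhs[%i] : memref<?xf32>
  ///   %c = arith.addf %a, %b : f32
  ///   memref.store %c, %out[%i] : memref<?xf32>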
  LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
                                             Location loc,
                                             ValueRange ivs) const {
    auto linalgOp = cast<LinalgOp>(op);
    if (!linalgOp.hasPureBufferSemantics())
      return op->emitOpError("expected operation to have buffer semantics");

    SmallVector<Value> indexedValues;
    indexedValues.reserve(linalgOp->getNumOperands());
    Location linalgOpLoc = op->getLoc();
    /// Load the data corresponding to the block arguments that
    /// represent input operands.
    for (OpOperand &operand : linalgOp->getOpOperands()) {
      if (!linalgOp.payloadUsesValueFromOperand(&operand)) {
        indexedValues.push_back(nullptr);
        continue;
      }
      if (linalgOp.isScalar(&operand)) {
        indexedValues.push_back(operand.get());
        continue;
      }
      SmallVector<Value> indices = getIndicesForAccess(
          builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(&operand), ivs);
      Value load =
          builder.create<memref::LoadOp>(linalgOpLoc, operand.get(), indices);
      indexedValues.push_back(load);
    }

    /// Inline the op payload and store the result.
    return inlinePayload(builder, linalgOp, ivs, indexedValues);
  }
};

//===----------------------------------------------------------------------===//
// External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
//===----------------------------------------------------------------------===//

/// External model implementation of PartialReductionInterface for LinalgOps.
template <typename LinalgOpTy>
struct LinalgOpPartialReductionInterface
    : public PartialReductionOpInterface::ExternalModel<
          LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> {
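  /// Create the accumulator tensor for a partial reduction, filled with the
  /// neutral element of the combiner op. For example (illustrative), for a
  /// sum reduction whose reduction dimension is tiled by `%tile`, this builds
  /// roughly:
  ///
  ///   %empty = tensor.empty(%d0, %tile) : tensor<?x?xf32>
  ///   %cst = arith.constant 0.0 : f32
  ///   %init = linalg.fill ins(%cst : f32) outs(%empty : tensor<?x?xf32>)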
  FailureOr<Operation *> generateInitialTensorForPartialReduction(
      Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
      ArrayRef<int> reductionDims) const {
    auto linalgOp = cast<LinalgOp>(op);
    OpBuilder::InsertionGuard guard(b);

    if (linalgOp.hasPureBufferSemantics())
      return op->emitOpError("expected operation to have tensor semantics");
    // Insert the new parallel dimension based on the index of the reduction
    // loops. This could be controlled by the user for more flexibility.

    SmallVector<Operation *, 4> combinerOps;
    if (!matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps) ||
        combinerOps.size() != 1)
      return op->emitOpError("failed to analyze the reduction operation");

    Operation *reductionOp = combinerOps[0];
    std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
    if (!identity.has_value())
      return op->emitOpError(
          "failed to get an identity value for the reduction operation");

    ArrayRef<int64_t> oldShape =
        linalgOp.getShape(linalgOp.getDpsInitOperand(0));

    // Calculate the new shape: the new dimensions are inserted based on the
    // indices of the reduction dimensions.
    SmallVector<int64_t> newOutputShape;
    SmallVector<Value> dynamicDims;
    int64_t currReductionDims = 0;
    DenseSet<int> reductionDimsSet(reductionDims.begin(), reductionDims.end());
    for (int64_t idx :
         llvm::seq<int64_t>(0, oldShape.size() + reductionDims.size())) {
      if (reductionDimsSet.contains(idx)) {
        dispatchIndexOpFoldResults(sizes[idx], dynamicDims, newOutputShape);
        currReductionDims++;
        continue;
      }
      int64_t oldIdx = idx - currReductionDims;
      int64_t dim = oldShape[oldIdx];
      newOutputShape.push_back(dim);
      if (ShapedType::isDynamic(dim))
        dynamicDims.push_back(b.create<tensor::DimOp>(
            loc, linalgOp.getDpsInitOperand(0)->get(), oldIdx));
    }
    Value emptyTensor = b.create<tensor::EmptyOp>(
        loc, newOutputShape, linalgOp.getRegionOutputArgs()[0].getType(),
        dynamicDims);
    Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
    auto identityTensor =
        b.create<linalg::FillOp>(loc, constantOp, emptyTensor);
    return identityTensor.getOperation();
  }

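  /// Tile the reduction into a partial reduction: the tiled reduction
  /// dimensions become parallel dimensions of the accumulator, so each tile
  /// combines into its own slice of the partial result. For example
  /// (illustrative), a sum over d1 becomes an accumulating `linalg.generic`
  /// with iterator types ["parallel", "parallel"] writing into the expanded
  /// init tensor.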
  Operation *tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
                                    ValueRange init,
                                    ArrayRef<OpFoldResult> offsets,
                                    ArrayRef<OpFoldResult> sizes,
                                    ArrayRef<int> reductionDims) const {
    OpBuilder::InsertionGuard guard(b);
    auto linalgOp = cast<LinalgOp>(op);

    AffineMap oldOutputMap =
        linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
    SmallVector<AffineExpr> outputExpr(oldOutputMap.getNumResults() +
                                       reductionDims.size());

    for (int idx : reductionDims)
      outputExpr[idx] = b.getAffineDimExpr(idx);
    int currExpr = 0;
    for (int idx : llvm::seq<int>(0, outputExpr.size())) {
      if (outputExpr[idx])
        continue;
      outputExpr[idx] = oldOutputMap.getResult(currExpr++);
    }

    // Step 1: Extract a slice of the input operands.
    SmallVector<Value> valuesToTile = linalgOp.getDpsInputs();
    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
        b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);

    // Step 2: Extract the accumulator operands.
    SmallVector<OpFoldResult> strides(offsets.size(), b.getIndexAttr(1));
    SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
    // TODO: use SubsetExtractOpInterface once it is available.
    Value out = b.create<tensor::ExtractSliceOp>(loc, init[0], outOffsets,
                                                 sizes, strides);

    // Step 3: Create a generic op where the reduction dimensions are replaced
    // by parallel dimensions of the size of the reduction tile.
    SmallVector<utils::IteratorType> newIteratorTypes =
        linalgOp.getIteratorTypesArray();
    for (int dim : reductionDims)
      newIteratorTypes[dim] = utils::IteratorType::parallel;
    SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
    newMaps.back() = AffineMap::get(newMaps.back().getNumDims(), 0, outputExpr,
                                    linalgOp.getContext());
    auto genericOp =
        b.create<GenericOp>(loc, TypeRange({out.getType()}), tiledOperands,
                            ValueRange({out}), newMaps, newIteratorTypes);
    IRMapping mapping;
    op->getRegion(0).cloneInto(&genericOp.getRegion(),
                               genericOp.getRegion().begin(), mapping);
    return genericOp.getOperation();
  }

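  /// Merge the partial results by reducing the dimensions that were introduced
  /// as parallel dimensions of the accumulator. For example (illustrative),
  /// merging a `tensor<?x4xf32>` partial sum into a `tensor<?xf32>` result
  /// produces roughly:
  ///
  ///   linalg.generic {indexing_maps = [(d0, d1) -> (d0, d1),
  ///                                    (d0, d1) -> (d0)],
  ///                   iterator_types = ["parallel", "reduction"]}
  ///       ins(%partial : tensor<?x4xf32>) outs(%init : tensor<?xf32>)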
  Operation *mergeReductions(Operation *op, OpBuilder &b, Location loc,
                             ValueRange partialReduce,
                             ArrayRef<int> reductionDims) const {
    auto linalgOp = cast<LinalgOp>(op);

    DenseSet<int> reductionDimsSet(reductionDims.begin(), reductionDims.end());

    // Create a new reduction that only reduces the newly added dimensions
    // from the previous op.
    int64_t intermRank = cast<ShapedType>(partialReduce[0].getType()).getRank();
    AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
    SmallVector<utils::IteratorType> reductionIteratorTypes;
    SmallVector<AffineExpr> exprs;

    for (int64_t i : llvm::seq<int64_t>(0, intermRank)) {
      if (reductionDimsSet.contains(i)) {
        reductionIteratorTypes.push_back(utils::IteratorType::reduction);
      } else {
        exprs.push_back(b.getAffineDimExpr(i));
        reductionIteratorTypes.push_back(utils::IteratorType::parallel);
      }
    }

    AffineMap outputMap =
        AffineMap::get(intermRank, 0, exprs, op->getContext());
    SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};

    SmallVector<Operation *, 4> combinerOps;
    matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps);
    Operation *reductionOp = combinerOps[0];

    auto reduction = b.create<GenericOp>(
        loc, op->getResultTypes(), ValueRange({partialReduce[0]}),
        linalgOp.getDpsInits(), reductionMaps, reductionIteratorTypes,
        [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
          Operation *clonedReductionOp = b.clone(*reductionOp);
          clonedReductionOp->setOperand(0, inputs[0]);
          clonedReductionOp->setOperand(1, inputs[1]);
          b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
        });
    return reduction.getOperation();
  }
};

} // namespace

template <typename OpType>
static void registerOne(MLIRContext *ctx) {
  OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx);
  OpType::template attachInterface<LinalgOpPartialReductionInterface<OpType>>(
      *ctx);
}

/// Variadic helper function.
template <typename... OpTypes>
static void registerAll(MLIRContext *ctx) {
  (registerOne<OpTypes>(ctx), ...);
}

#define GET_OP_LIST

void mlir::linalg::registerTilingInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
    registerOne<linalg::GenericOp>(ctx);
    registerAll<
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
        >(ctx);
  });
}