1//===- Generalization.cpp - linalg named ops to generic ops --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the Linalg generalization pass. It converts named
10// Linalg ops to linalg.generic ops.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Linalg/Passes.h"
15
16#include "mlir/Dialect/Linalg/IR/Linalg.h"
17#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18#include "mlir/IR/AffineMap.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/PatternMatch.h"
21#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22
23namespace mlir {
24#define GEN_PASS_DEF_LINALGGENERALIZENAMEDOPSPASS
25#include "mlir/Dialect/Linalg/Passes.h.inc"
26} // namespace mlir
27
28#define DEBUG_TYPE "linalg-generalization"
29
30using namespace mlir;
31using namespace mlir::linalg;
32
33static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) {
34 // Bailout if `linalgOp` is already a generic or a linalg.map. We cannot
35 // trivially generalize a `linalg.map`, as it does not use the output as
36 // region arguments in the block.
37 if (isa<GenericOp>(Val: linalgOp) || isa<MapOp>(Val: linalgOp))
38 return failure();
39 // Check if the operation has exactly one region.
40 if (linalgOp->getNumRegions() != 1) {
41 assert(linalgOp->getNumRegions() == 0 && "op with multiple regions");
42 // TOD: Otherwise it needs to be built explicitly from the region builder.
43 return failure();
44 }
45 return success();
46}
47
48FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter,
49 LinalgOp linalgOp) {
50 if (failed(Result: generalizeNamedOpPrecondition(linalgOp)))
51 return rewriter.notifyMatchFailure(arg&: linalgOp, msg: "preconditions not met");
52
53 SmallVector<Value> inputs = linalgOp.getDpsInputs();
54 ValueRange outputs = linalgOp.getDpsInits();
55 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
56 SmallVector<utils::IteratorType> iterators = linalgOp.getIteratorTypesArray();
57 SmallVector<Type> resultTypes = linalgOp.hasPureTensorSemantics()
58 ? TypeRange(ValueRange(outputs))
59 : TypeRange{};
60
61 // All named ops have a region attached that can be inlined.
62 assert(linalgOp->getNumRegions() == 1 &&
63 "expect named op to have one region attached");
64 GenericOp genericOp = rewriter.create<GenericOp>(
65 location: linalgOp.getLoc(), args&: resultTypes, args&: inputs, args&: outputs, args&: indexingMaps, args&: iterators);
66 rewriter.inlineRegionBefore(region&: linalgOp->getRegion(index: 0), parent&: genericOp.getRegion(),
67 before: genericOp.getRegion().begin());
68 rewriter.replaceOp(op: linalgOp, newValues: genericOp->getResults());
69 return genericOp;
70}
71
72namespace {
73
74struct LinalgGeneralizeNamedOpsPass
75 : public impl::LinalgGeneralizeNamedOpsPassBase<
76 LinalgGeneralizeNamedOpsPass> {
77 using impl::LinalgGeneralizeNamedOpsPassBase<
78 LinalgGeneralizeNamedOpsPass>::LinalgGeneralizeNamedOpsPassBase;
79 void runOnOperation() override;
80};
81
82} // namespace
83
84void LinalgGeneralizeNamedOpsPass::runOnOperation() {
85 RewritePatternSet patterns(&getContext());
86 populateLinalgNamedOpsGeneralizationPatterns(patterns);
87 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
88}
89
90void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
91 RewritePatternSet &patterns) {
92 patterns.add<LinalgGeneralizationPattern>(arg: patterns.getContext());
93}
94

source code of mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp