1//===- Generalization.cpp - linalg named ops to generic ops --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the Linalg generalization pass. It converts named
10// Linalg ops to linalg.generic ops.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Linalg/Passes.h"
15
16#include "mlir/Dialect/Func/IR/FuncOps.h"
17#include "mlir/Dialect/Linalg/IR/Linalg.h"
18#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19#include "mlir/IR/AffineMap.h"
20#include "mlir/IR/Attributes.h"
21#include "mlir/IR/Builders.h"
22#include "mlir/IR/ImplicitLocOpBuilder.h"
23#include "mlir/IR/PatternMatch.h"
24#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/Support/Debug.h"
27
28namespace mlir {
29#define GEN_PASS_DEF_LINALGGENERALIZENAMEDOPSPASS
30#include "mlir/Dialect/Linalg/Passes.h.inc"
31} // namespace mlir
32
33#define DEBUG_TYPE "linalg-generalization"
34
35using namespace mlir;
36using namespace mlir::linalg;
37
38static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) {
39 // Bailout if `linalgOp` is already a generic or a linalg.map. We cannot
40 // trivially generalize a `linalg.map`, as it does not use the output as
41 // region arguments in the block.
42 if (isa<GenericOp>(linalgOp) || isa<MapOp>(linalgOp))
43 return failure();
44 // Check if the operation has exactly one region.
45 if (linalgOp->getNumRegions() != 1) {
46 assert(linalgOp->getNumRegions() == 0 && "op with multiple regions");
47 // TOD: Otherwise it needs to be built explicitly from the region builder.
48 return failure();
49 }
50 return success();
51}
52
53FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter,
54 LinalgOp linalgOp) {
55 if (failed(generalizeNamedOpPrecondition(linalgOp)))
56 return rewriter.notifyMatchFailure(linalgOp, "preconditions not met");
57
58 SmallVector<Value> inputs = linalgOp.getDpsInputs();
59 ValueRange outputs = linalgOp.getDpsInits();
60 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
61 SmallVector<utils::IteratorType> iterators = linalgOp.getIteratorTypesArray();
62 SmallVector<Type> resultTypes = linalgOp.hasPureTensorSemantics()
63 ? TypeRange(ValueRange(outputs))
64 : TypeRange{};
65
66 // All named ops have a region attached that can be inlined.
67 assert(linalgOp->getNumRegions() == 1 &&
68 "expect named op to have one region attached");
69 GenericOp genericOp = rewriter.create<GenericOp>(
70 linalgOp.getLoc(), resultTypes, inputs, outputs, indexingMaps, iterators);
71 rewriter.inlineRegionBefore(linalgOp->getRegion(0), genericOp.getRegion(),
72 genericOp.getRegion().begin());
73 rewriter.replaceOp(linalgOp, genericOp->getResults());
74 return genericOp;
75}
76
77namespace {
78
79struct LinalgGeneralizeNamedOpsPass
80 : public impl::LinalgGeneralizeNamedOpsPassBase<
81 LinalgGeneralizeNamedOpsPass> {
82 using impl::LinalgGeneralizeNamedOpsPassBase<
83 LinalgGeneralizeNamedOpsPass>::LinalgGeneralizeNamedOpsPassBase;
84 void runOnOperation() override;
85};
86
87} // namespace
88
89void LinalgGeneralizeNamedOpsPass::runOnOperation() {
90 RewritePatternSet patterns(&getContext());
91 populateLinalgNamedOpsGeneralizationPatterns(patterns);
92 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
93}
94
95void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
96 RewritePatternSet &patterns) {
97 patterns.add<LinalgGeneralizationPattern>(arg: patterns.getContext());
98}
99

source code of mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp