1 | //===- RewriteAsConstant.cpp - Patterns to rewrite tensor ops as constants ===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
10 | #include "mlir/Dialect/Tensor/Transforms/Transforms.h" |
11 | #include "mlir/IR/Matchers.h" |
12 | #include "mlir/IR/PatternMatch.h" |
13 | |
14 | using namespace mlir; |
15 | using namespace mlir::tensor; |
16 | |
17 | namespace { |
18 | |
19 | /// Rewrite tensor.generate with arith.constant if the yielded value is a |
20 | /// constant and the tensor type is static. |
21 | struct GenerateToConstant : public OpRewritePattern<GenerateOp> { |
22 | using OpRewritePattern<GenerateOp>::OpRewritePattern; |
23 | |
24 | LogicalResult matchAndRewrite(GenerateOp generateOp, |
25 | PatternRewriter &rewriter) const override { |
26 | auto tensorType = |
27 | llvm::cast<RankedTensorType>(generateOp.getResult().getType()); |
28 | if (!tensorType.hasStaticShape()) |
29 | return failure(); |
30 | auto terminatorOp = |
31 | cast<tensor::YieldOp>(generateOp.getBody().front().getTerminator()); |
32 | Attribute attr; |
33 | if (!matchPattern(terminatorOp.getValue(), m_Constant(bind_value: &attr))) |
34 | return failure(); |
35 | Operation *constantOp = |
36 | rewriter.getContext() |
37 | ->getLoadedDialect<TensorDialect>() |
38 | ->materializeConstant(rewriter, |
39 | DenseElementsAttr::get(tensorType, attr), |
40 | tensorType, generateOp->getLoc()); |
41 | if (!constantOp) |
42 | return failure(); |
43 | rewriter.replaceOp(generateOp, constantOp->getResults()); |
44 | return success(); |
45 | } |
46 | }; |
47 | |
48 | } // namespace |
49 | |
50 | void mlir::tensor::populateRewriteAsConstantPatterns( |
51 | RewritePatternSet &patterns) { |
52 | patterns.add<GenerateToConstant>(arg: patterns.getContext()); |
53 | } |
54 | |