//===- RewriteAsConstant.cpp - Patterns to rewrite tensor ops as constants ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
9#include "mlir/Dialect/Tensor/IR/Tensor.h"
10#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
11#include "mlir/IR/Matchers.h"
12#include "mlir/IR/PatternMatch.h"
13
14using namespace mlir;
15using namespace mlir::tensor;
16
17namespace {
18
19/// Rewrite tensor.generate with arith.constant if the yielded value is a
20/// constant and the tensor type is static.
21struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
22 using OpRewritePattern<GenerateOp>::OpRewritePattern;
23
24 LogicalResult matchAndRewrite(GenerateOp generateOp,
25 PatternRewriter &rewriter) const override {
26 auto tensorType =
27 llvm::cast<RankedTensorType>(generateOp.getResult().getType());
28 if (!tensorType.hasStaticShape())
29 return failure();
30 auto terminatorOp =
31 cast<tensor::YieldOp>(generateOp.getBody().front().getTerminator());
32 Attribute attr;
33 if (!matchPattern(terminatorOp.getValue(), m_Constant(bind_value: &attr)))
34 return failure();
35 Operation *constantOp =
36 rewriter.getContext()
37 ->getLoadedDialect<TensorDialect>()
38 ->materializeConstant(rewriter,
39 DenseElementsAttr::get(tensorType, attr),
40 tensorType, generateOp->getLoc());
41 if (!constantOp)
42 return failure();
43 rewriter.replaceOp(generateOp, constantOp->getResults());
44 return success();
45 }
46};
47
48} // namespace
49
50void mlir::tensor::populateRewriteAsConstantPatterns(
51 RewritePatternSet &patterns) {
52 patterns.add<GenerateToConstant>(arg: patterns.getContext());
53}
54

source code of mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp