1//===- ConversionUtils.h - Helper functions for tosa conversion -*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Utility functions for TOSA lowering
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef DIALECT_TOSA_UTILS_COVERSION_UTILS_H_
14#define DIALECT_TOSA_UTILS_COVERSION_UTILS_H_
15
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Tensor/IR/Tensor.h"
18#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
19#include "mlir/IR/PatternMatch.h"
20#include <optional>
21
22namespace mlir {
23namespace tosa {
24
25// Creates a SmallVector of Stringrefs for N parallel loops
26SmallVector<utils::IteratorType>
27getNParallelLoopsAttrs(unsigned nParallelLoops);
28
29// Takes a vector of values and condenses them to a vector with no gaps.
30SmallVector<Value> condenseValues(const SmallVector<Value> &values);
31
32// Takes the parameters for a clamp and turns it into a series of ops for float
33// inputs.
34Value clampFloatHelper(Location loc, Value arg, Value min, Value max,
35 OpBuilder &rewriter);
36
37// Takes the parameters for a clamp and turns it into a series of ops for
38// integer inputs.
39Value clampIntHelper(Location loc, Value arg, Value min, Value max,
40 OpBuilder &rewriter);
41
42// Determines whether the integer value falls witin the range of integer type.
43bool validIntegerRange(IntegerType ty, int64_t value);
44
45// Checks for a dynamic batch dim in any of the passed parameters of an op.
46// The batch dimention must be #0 and the rest of the dimensions must be static.
47template <typename Op>
48std::optional<SmallVector<Value>>
49checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op,
50 ArrayRef<Value> params) {
51 SmallVector<ShapedType> dynTypes;
52 SmallVector<Value> dynamicDims;
53 for (const Value &param : params) {
54 auto paramTy = cast<ShapedType>(param.getType());
55 if (!paramTy.hasStaticShape())
56 dynTypes.push_back(paramTy);
57 }
58
59 if (dynTypes.empty())
60 return dynamicDims;
61
62 for (const ShapedType &dynTy : dynTypes) {
63 if (llvm::any_of(dynTy.getShape().drop_front(), ShapedType::isDynamic)) {
64 (void)rewriter.notifyMatchFailure(
65 op, "input can only be dynamic for batch size");
66 return std::nullopt;
67 }
68 }
69
70 dynamicDims.push_back(
71 rewriter.create<tensor::DimOp>(op->getLoc(), params[0], 0));
72 return dynamicDims;
73}
74
75/// Common code to create the reshape op where necessary to make the rank of two
76/// values equal. input1 and input2 will be updated when the rank has
77/// changed. The caller is expected to use these to rewrite the original
78/// operator with the RESHAPE now in the graph.
79LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc,
80 Value &input1, Value &input2);
81
82} // namespace tosa
83} // namespace mlir
84
85#endif // DIALECT_TOSA_UTILS_COVERSION_UTILS_H_
86

// Source: mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h