1 | //===- ConversionUtils.h - Helper functions for tosa conversion -*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Utility functions for TOSA lowering |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef DIALECT_TOSA_UTILS_COVERSION_UTILS_H_ |
14 | #define DIALECT_TOSA_UTILS_COVERSION_UTILS_H_ |
15 | |
16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
17 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
18 | #include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
19 | #include "mlir/IR/PatternMatch.h" |
20 | #include <optional> |
21 | |
22 | namespace mlir { |
23 | namespace tosa { |
24 | |
25 | // Creates a SmallVector of Stringrefs for N parallel loops |
26 | SmallVector<utils::IteratorType> |
27 | getNParallelLoopsAttrs(unsigned nParallelLoops); |
28 | |
29 | // Takes a vector of values and condenses them to a vector with no gaps. |
30 | SmallVector<Value> condenseValues(const SmallVector<Value> &values); |
31 | |
32 | // Takes the parameters for a clamp and turns it into a series of ops for float |
33 | // inputs. |
34 | Value clampFloatHelper(Location loc, Value arg, Value min, Value max, |
35 | OpBuilder &rewriter); |
36 | |
37 | // Takes the parameters for a clamp and turns it into a series of ops for |
38 | // integer inputs. |
39 | Value clampIntHelper(Location loc, Value arg, Value min, Value max, |
40 | OpBuilder &rewriter); |
41 | |
42 | // Determines whether the integer value falls witin the range of integer type. |
43 | bool validIntegerRange(IntegerType ty, int64_t value); |
44 | |
45 | // Checks for a dynamic batch dim in any of the passed parameters of an op. |
46 | // The batch dimention must be #0 and the rest of the dimensions must be static. |
47 | template <typename Op> |
48 | std::optional<SmallVector<Value>> |
49 | checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op, |
50 | ArrayRef<Value> params) { |
51 | SmallVector<ShapedType> dynTypes; |
52 | SmallVector<Value> dynamicDims; |
53 | for (const Value ¶m : params) { |
54 | auto paramTy = cast<ShapedType>(param.getType()); |
55 | if (!paramTy.hasStaticShape()) |
56 | dynTypes.push_back(paramTy); |
57 | } |
58 | |
59 | if (dynTypes.empty()) |
60 | return dynamicDims; |
61 | |
62 | for (const ShapedType &dynTy : dynTypes) { |
63 | if (llvm::any_of(dynTy.getShape().drop_front(), ShapedType::isDynamic)) { |
64 | (void)rewriter.notifyMatchFailure( |
65 | op, "input can only be dynamic for batch size" ); |
66 | return std::nullopt; |
67 | } |
68 | } |
69 | |
70 | dynamicDims.push_back( |
71 | rewriter.create<tensor::DimOp>(op->getLoc(), params[0], 0)); |
72 | return dynamicDims; |
73 | } |
74 | |
75 | /// Common code to create the reshape op where necessary to make the rank of two |
76 | /// values equal. input1 and input2 will be updated when the rank has |
77 | /// changed. The caller is expected to use these to rewrite the original |
78 | /// operator with the RESHAPE now in the graph. |
79 | LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, |
80 | Value &input1, Value &input2); |
81 | |
82 | } // namespace tosa |
83 | } // namespace mlir |
84 | |
85 | #endif // DIALECT_TOSA_UTILS_COVERSION_UTILS_H_ |
86 | |