//===- ConversionUtils.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Utility functions for TOSA lowering
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"

using namespace mlir;
using namespace mlir::tosa;

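// Returns `nParallelLoops` copies of the `parallel` iterator type, the form
// expected for the iterator-types list of a fully parallel linalg op built
// during TOSA lowering.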
SmallVector<utils::IteratorType>
mlir::tosa::getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<utils::IteratorType>(nParallelLoops,
                                          utils::IteratorType::parallel);
}

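// Removes all null entries from `values`, preserving the relative order of
// the remaining ones.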
SmallVector<Value>
mlir::tosa::condenseValues(const SmallVector<Value> &values) {
  SmallVector<Value> condensedValues;
  for (auto value : values)
    if (value)
      condensedValues.push_back(value);
  return condensedValues;
}

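// Clamps a floating-point `arg` to [min, max] by emitting
// arith.minimumf(arg, max) followed by arith.maximumf(result, min).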
Value mlir::tosa::clampFloatHelper(Location loc, Value arg, Value min,
                                   Value max, OpBuilder &rewriter) {
  Value minValue = rewriter.create<arith::MinimumFOp>(loc, arg, max);
  return rewriter.create<arith::MaximumFOp>(loc, minValue, min);
}

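// Clamps a signed integer `arg` to [min, max] using arith.maxsi followed by
// arith.minsi.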
Value mlir::tosa::clampIntHelper(Location loc, Value arg, Value min, Value max,
                                 OpBuilder &rewriter) {
  auto minOrArg = rewriter.create<arith::MaxSIOp>(loc, min, arg);
  return rewriter.create<arith::MinSIOp>(loc, max, minOrArg);
}

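// Returns true if `value` is representable in the integer type `ty`,
// honoring the type's signedness; e.g. 300 fits in ui16 but not in i8.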
bool mlir::tosa::validIntegerRange(IntegerType ty, int64_t value) {
  uint64_t bitwidth = ty.getIntOrFloatBitWidth();
  if (ty.getSignedness() == IntegerType::Unsigned) {
    uint64_t uvalue = value;
    APInt intMin = APInt::getMinValue(bitwidth);
    APInt intMax = APInt::getMaxValue(bitwidth);
    return uvalue >= intMin.getZExtValue() && uvalue <= intMax.getZExtValue();
  }

  APInt intMin = APInt::getSignedMinValue(bitwidth);
  APInt intMax = APInt::getSignedMaxValue(bitwidth);
  return value >= intMin.getSExtValue() && value <= intMax.getSExtValue();
}

namespace {
// Given two tensors of high and low ranks, derive the output shape
// to reshape the lower-ranked tensor to.
// Examples:
// If lower=[c], higher=[a, b, c], [c] is reshaped into [1, 1, c].
// If lower=[b, c], higher=[a, b, c], [b, c] is reshaped into [1, b, c].
// If lower=[a], higher=[a, a], [a] is reshaped into [1, a].
// If lower=[a], higher=[a, b, a], [a] is reshaped into [1, 1, a].
// If lower=[], higher=[a, b, c], [] is reshaped into [1, 1, 1].
LogicalResult
computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
                     ArrayRef<int64_t> lowerRankShape,
                     SmallVectorImpl<int64_t> &reshapeOutputShape) {
  // Initialize new shapes with [1] * higherRank.
  int64_t higherRank = higherRankShape.size();
  int64_t lowerRank = lowerRankShape.size();

  reshapeOutputShape.assign(higherRank, 1);

  int64_t higherRankDim;
  int64_t lowerRankDim;

  for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0;
       i--, j--) {
    higherRankDim = higherRankShape[i];
    lowerRankDim = lowerRankShape[j];

    if (lowerRankDim == 1 && higherRankDim > 1)
      reshapeOutputShape[i] = 1;
    else if ((lowerRankDim > 1 && higherRankDim == 1) ||
             (lowerRankDim == higherRankDim))
      reshapeOutputShape[i] = lowerRankDim;
    else if (higherRankDim != lowerRankDim)
      return failure();
  }
  return success();
}
} // namespace

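// Reshapes the lower-ranked of `input1`/`input2` with a tosa.reshape that
// prepends size-1 dimensions, so that both operands end up with the same
// rank. Fails if either operand is not a ranked tensor or if the shapes
// cannot be aligned.
//
// Illustrative sketch (not part of this file; the pattern name and the use of
// tosa.add are assumptions) of how a rewrite pattern could use this helper to
// make a binary op's operands broadcast-compatible:
//
//   struct EqualizeAddRanks : public OpRewritePattern<tosa::AddOp> {
//     using OpRewritePattern::OpRewritePattern;
//     LogicalResult matchAndRewrite(tosa::AddOp op,
//                                   PatternRewriter &rewriter) const override {
//       Value lhs = op.getInput1();
//       Value rhs = op.getInput2();
//       auto lhsTy = dyn_cast<RankedTensorType>(lhs.getType());
//       auto rhsTy = dyn_cast<RankedTensorType>(rhs.getType());
//       if (!lhsTy || !rhsTy || lhsTy.getRank() == rhsTy.getRank())
//         return failure();
//       if (failed(EqualizeRanks(rewriter, op.getLoc(), lhs, rhs)))
//         return failure();
//       rewriter.replaceOpWithNewOp<tosa::AddOp>(op, op.getType(), lhs, rhs);
//       return success();
//     }
//   };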
LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
                                        Value &input1, Value &input2) {
  auto input1Ty = llvm::dyn_cast<RankedTensorType>(input1.getType());
  auto input2Ty = llvm::dyn_cast<RankedTensorType>(input2.getType());

  if (!input1Ty || !input2Ty) {
    return failure();
  }

  int64_t input1Rank = input1Ty.getRank();
  int64_t input2Rank = input2Ty.getRank();

  if (input1Rank == input2Rank)
    return success();

  Value higherTensorValue, lowerTensorValue;
  if (input1Rank > input2Rank) {
    higherTensorValue = input1;
    lowerTensorValue = input2;
  } else {
    higherTensorValue = input2;
    lowerTensorValue = input1;
  }

  ArrayRef<int64_t> higherRankShape =
      llvm::cast<RankedTensorType>(higherTensorValue.getType()).getShape();
  ArrayRef<int64_t> lowerRankShape =
      llvm::cast<RankedTensorType>(lowerTensorValue.getType()).getShape();

  SmallVector<int64_t, 4> reshapeOutputShape;

  if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape)
          .failed())
    return failure();

  auto reshapeInputType =
      llvm::cast<RankedTensorType>(lowerTensorValue.getType());
  auto reshapeOutputType = RankedTensorType::get(
      ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());

  auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
      loc, reshapeOutputType, lowerTensorValue,
      rewriter.getDenseI64ArrayAttr(reshapeOutputShape));

  if (input1Rank > input2Rank) {
    input1 = higherTensorValue;
    input2 = reshapeLower.getResult();
  } else {
    input1 = reshapeLower.getResult();
    input2 = higherTensorValue;
  }

  return success();
}