1//===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utilities for the Linalg dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Arith/Utils/Utils.h"
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/Complex/IR/Complex.h"
16#include "mlir/IR/ImplicitLocOpBuilder.h"
17#include "llvm/ADT/SmallBitVector.h"
18#include <numeric>
19
20using namespace mlir;
21
22/// Matches a ConstantIndexOp.
23/// TODO: This should probably just be a general matcher that uses matchConstant
24/// and checks the operation for an index type.
25detail::op_matcher<arith::ConstantIndexOp> mlir::matchConstantIndex() {
26 return detail::op_matcher<arith::ConstantIndexOp>();
27}
28
29llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
30 ArrayRef<int64_t> shape) {
31 llvm::SmallBitVector dimsToProject(shape.size());
32 for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) {
33 if (shape[pos] == 1) {
34 dimsToProject.set(pos);
35 --rank;
36 }
37 }
38 return dimsToProject;
39}
40
41Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
42 OpFoldResult ofr) {
43 if (auto value = llvm::dyn_cast_if_present<Value>(Val&: ofr))
44 return value;
45 auto attr = dyn_cast<IntegerAttr>(llvm::dyn_cast_if_present<Attribute>(Val&: ofr));
46 assert(attr && "expect the op fold result casts to an integer attribute");
47 return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
48}
49
50Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
51 Type targetType, Value value) {
52 if (targetType == value.getType())
53 return value;
54
55 bool targetIsIndex = targetType.isIndex();
56 bool valueIsIndex = value.getType().isIndex();
57 if (targetIsIndex ^ valueIsIndex)
58 return b.create<arith::IndexCastOp>(loc, targetType, value);
59
60 auto targetIntegerType = dyn_cast<IntegerType>(targetType);
61 auto valueIntegerType = dyn_cast<IntegerType>(value.getType());
62 assert(targetIntegerType && valueIntegerType &&
63 "unexpected cast between types other than integers and index");
64 assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
65
66 if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
67 return b.create<arith::ExtSIOp>(loc, targetIntegerType, value);
68 return b.create<arith::TruncIOp>(loc, targetIntegerType, value);
69}
70
71static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand,
72 IntegerType toType, bool isUnsigned) {
73 // If operand is floating point, cast directly to the int type.
74 if (isa<FloatType>(Val: operand.getType())) {
75 if (isUnsigned)
76 return b.create<arith::FPToUIOp>(toType, operand);
77 return b.create<arith::FPToSIOp>(toType, operand);
78 }
79 // Cast index operands directly to the int type.
80 if (operand.getType().isIndex())
81 return b.create<arith::IndexCastOp>(toType, operand);
82 if (auto fromIntType = dyn_cast<IntegerType>(operand.getType())) {
83 // Either extend or truncate.
84 if (toType.getWidth() > fromIntType.getWidth()) {
85 if (isUnsigned)
86 return b.create<arith::ExtUIOp>(toType, operand);
87 return b.create<arith::ExtSIOp>(toType, operand);
88 }
89 if (toType.getWidth() < fromIntType.getWidth())
90 return b.create<arith::TruncIOp>(toType, operand);
91 return operand;
92 }
93
94 return {};
95}
96
97static Value convertScalarToFpDtype(ImplicitLocOpBuilder &b, Value operand,
98 FloatType toType, bool isUnsigned) {
99 // If operand is integer, cast directly to the float type.
100 // Note that it is unclear how to cast from BF16<->FP16.
101 if (isa<IntegerType>(Val: operand.getType())) {
102 if (isUnsigned)
103 return b.create<arith::UIToFPOp>(toType, operand);
104 return b.create<arith::SIToFPOp>(toType, operand);
105 }
106 if (auto fromFpTy = dyn_cast<FloatType>(Val: operand.getType())) {
107 if (toType.getWidth() > fromFpTy.getWidth())
108 return b.create<arith::ExtFOp>(toType, operand);
109 if (toType.getWidth() < fromFpTy.getWidth())
110 return b.create<arith::TruncFOp>(toType, operand);
111 return operand;
112 }
113
114 return {};
115}
116
117static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand,
118 ComplexType targetType,
119 bool isUnsigned) {
120 if (auto fromComplexType = dyn_cast<ComplexType>(operand.getType())) {
121 if (isa<FloatType>(targetType.getElementType()) &&
122 isa<FloatType>(fromComplexType.getElementType())) {
123 Value real = b.create<complex::ReOp>(operand);
124 Value imag = b.create<complex::ImOp>(operand);
125 Type targetETy = targetType.getElementType();
126 if (targetType.getElementType().getIntOrFloatBitWidth() <
127 fromComplexType.getElementType().getIntOrFloatBitWidth()) {
128 real = b.create<arith::TruncFOp>(targetETy, real);
129 imag = b.create<arith::TruncFOp>(targetETy, imag);
130 } else {
131 real = b.create<arith::ExtFOp>(targetETy, real);
132 imag = b.create<arith::ExtFOp>(targetETy, imag);
133 }
134 return b.create<complex::CreateOp>(targetType, real, imag);
135 }
136 }
137
138 if (dyn_cast<FloatType>(Val: operand.getType())) {
139 FloatType toFpTy = cast<FloatType>(targetType.getElementType());
140 auto toBitwidth = toFpTy.getIntOrFloatBitWidth();
141 Value from = operand;
142 if (from.getType().getIntOrFloatBitWidth() < toBitwidth) {
143 from = b.create<arith::ExtFOp>(toFpTy, from);
144 }
145 if (from.getType().getIntOrFloatBitWidth() > toBitwidth) {
146 from = b.create<arith::TruncFOp>(toFpTy, from);
147 }
148 Value zero = b.create<mlir::arith::ConstantFloatOp>(
149 args: mlir::APFloat(toFpTy.getFloatSemantics(), 0), args&: toFpTy);
150 return b.create<complex::CreateOp>(targetType, from, zero);
151 }
152
153 if (dyn_cast<IntegerType>(operand.getType())) {
154 FloatType toFpTy = cast<FloatType>(targetType.getElementType());
155 Value from = operand;
156 if (isUnsigned) {
157 from = b.create<arith::UIToFPOp>(toFpTy, from);
158 } else {
159 from = b.create<arith::SIToFPOp>(toFpTy, from);
160 }
161 Value zero = b.create<mlir::arith::ConstantFloatOp>(
162 args: mlir::APFloat(toFpTy.getFloatSemantics(), 0), args&: toFpTy);
163 return b.create<complex::CreateOp>(targetType, from, zero);
164 }
165
166 return {};
167}
168
169Value mlir::convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
170 Type toType, bool isUnsignedCast) {
171 if (operand.getType() == toType)
172 return operand;
173 ImplicitLocOpBuilder ib(loc, b);
174 Value result;
175 if (auto intTy = dyn_cast<IntegerType>(toType)) {
176 result = convertScalarToIntDtype(ib, operand, intTy, isUnsignedCast);
177 } else if (auto floatTy = dyn_cast<FloatType>(Val&: toType)) {
178 result = convertScalarToFpDtype(b&: ib, operand, toType: floatTy, isUnsigned: isUnsignedCast);
179 } else if (auto complexTy = dyn_cast<ComplexType>(toType)) {
180 result =
181 convertScalarToComplexDtype(ib, operand, complexTy, isUnsignedCast);
182 }
183
184 if (result)
185 return result;
186
187 emitWarning(loc) << "could not cast operand of type " << operand.getType()
188 << " to " << toType;
189 return operand;
190}
191
192SmallVector<Value>
193mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
194 ArrayRef<OpFoldResult> valueOrAttrVec) {
195 return llvm::to_vector<4>(
196 Range: llvm::map_range(C&: valueOrAttrVec, F: [&](OpFoldResult value) -> Value {
197 return getValueOrCreateConstantIndexOp(b, loc, ofr: value);
198 }));
199}
200
201Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
202 Type type, const APInt &value) {
203 TypedAttr attr;
204 if (isa<IntegerType>(Val: type)) {
205 attr = builder.getIntegerAttr(type, value);
206 } else {
207 auto vecTy = cast<ShapedType>(type);
208 attr = SplatElementsAttr::get(vecTy, value);
209 }
210
211 return builder.create<arith::ConstantOp>(loc, attr);
212}
213
214Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
215 Type type, int64_t value) {
216 unsigned elementBitWidth = 0;
217 if (auto intTy = dyn_cast<IntegerType>(type))
218 elementBitWidth = intTy.getWidth();
219 else
220 elementBitWidth = cast<ShapedType>(type).getElementTypeBitWidth();
221
222 return createScalarOrSplatConstant(builder, loc, type,
223 value: APInt(elementBitWidth, value));
224}
225
226Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
227 Type type, const APFloat &value) {
228 if (isa<FloatType>(type))
229 return builder.createOrFold<arith::ConstantOp>(
230 loc, type, builder.getFloatAttr(type, value));
231 TypedAttr splat = SplatElementsAttr::get(cast<ShapedType>(type), value);
232 return builder.createOrFold<arith::ConstantOp>(loc, type, splat);
233}
234
235Value ArithBuilder::_and(Value lhs, Value rhs) {
236 return b.create<arith::AndIOp>(loc, lhs, rhs);
237}
238Value ArithBuilder::add(Value lhs, Value rhs) {
239 if (isa<FloatType>(lhs.getType()))
240 return b.create<arith::AddFOp>(loc, lhs, rhs);
241 return b.create<arith::AddIOp>(loc, lhs, rhs);
242}
243Value ArithBuilder::sub(Value lhs, Value rhs) {
244 if (isa<FloatType>(lhs.getType()))
245 return b.create<arith::SubFOp>(loc, lhs, rhs);
246 return b.create<arith::SubIOp>(loc, lhs, rhs);
247}
248Value ArithBuilder::mul(Value lhs, Value rhs) {
249 if (isa<FloatType>(lhs.getType()))
250 return b.create<arith::MulFOp>(loc, lhs, rhs);
251 return b.create<arith::MulIOp>(loc, lhs, rhs);
252}
253Value ArithBuilder::sgt(Value lhs, Value rhs) {
254 if (isa<FloatType>(lhs.getType()))
255 return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs);
256 return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs);
257}
258Value ArithBuilder::slt(Value lhs, Value rhs) {
259 if (isa<FloatType>(lhs.getType()))
260 return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs);
261 return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs);
262}
263Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) {
264 return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);
265}
266
267namespace mlir::arith {
268
269Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values) {
270 return createProduct(builder, loc, values, resultType: values.front().getType());
271}
272
273Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
274 Type resultType) {
275 Value one = builder.create<ConstantOp>(loc, resultType,
276 builder.getOneAttr(resultType));
277 ArithBuilder arithBuilder(builder, loc);
278 return std::accumulate(
279 first: values.begin(), last: values.end(), init: one,
280 binary_op: [&arithBuilder](Value acc, Value v) { return arithBuilder.mul(lhs: acc, rhs: v); });
281}
282
283} // namespace mlir::arith
284

source code of mlir/lib/Dialect/Arith/Utils/Utils.cpp