//===- Utils.cpp - Utilities to support the Arith dialect -----------------===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file implements utilities for the Arith dialect.
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
15 | #include "mlir/Dialect/Complex/IR/Complex.h" |
16 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
17 | #include "llvm/ADT/SmallBitVector.h" |
18 | #include <numeric> |
19 | |
20 | using namespace mlir; |
21 | |
22 | /// Matches a ConstantIndexOp. |
23 | /// TODO: This should probably just be a general matcher that uses matchConstant |
24 | /// and checks the operation for an index type. |
25 | detail::op_matcher<arith::ConstantIndexOp> mlir::matchConstantIndex() { |
26 | return detail::op_matcher<arith::ConstantIndexOp>(); |
27 | } |
28 | |
29 | llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank, |
30 | ArrayRef<int64_t> shape) { |
31 | llvm::SmallBitVector dimsToProject(shape.size()); |
32 | for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) { |
33 | if (shape[pos] == 1) { |
34 | dimsToProject.set(pos); |
35 | --rank; |
36 | } |
37 | } |
38 | return dimsToProject; |
39 | } |
40 | |
41 | Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, |
42 | OpFoldResult ofr) { |
43 | if (auto value = llvm::dyn_cast_if_present<Value>(Val&: ofr)) |
44 | return value; |
45 | auto attr = dyn_cast<IntegerAttr>(llvm::dyn_cast_if_present<Attribute>(Val&: ofr)); |
46 | assert(attr && "expect the op fold result casts to an integer attribute" ); |
47 | return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue()); |
48 | } |
49 | |
50 | Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, |
51 | Type targetType, Value value) { |
52 | if (targetType == value.getType()) |
53 | return value; |
54 | |
55 | bool targetIsIndex = targetType.isIndex(); |
56 | bool valueIsIndex = value.getType().isIndex(); |
57 | if (targetIsIndex ^ valueIsIndex) |
58 | return b.create<arith::IndexCastOp>(loc, targetType, value); |
59 | |
60 | auto targetIntegerType = dyn_cast<IntegerType>(targetType); |
61 | auto valueIntegerType = dyn_cast<IntegerType>(value.getType()); |
62 | assert(targetIntegerType && valueIntegerType && |
63 | "unexpected cast between types other than integers and index" ); |
64 | assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness()); |
65 | |
66 | if (targetIntegerType.getWidth() > valueIntegerType.getWidth()) |
67 | return b.create<arith::ExtSIOp>(loc, targetIntegerType, value); |
68 | return b.create<arith::TruncIOp>(loc, targetIntegerType, value); |
69 | } |
70 | |
71 | static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand, |
72 | IntegerType toType, bool isUnsigned) { |
73 | // If operand is floating point, cast directly to the int type. |
74 | if (isa<FloatType>(Val: operand.getType())) { |
75 | if (isUnsigned) |
76 | return b.create<arith::FPToUIOp>(toType, operand); |
77 | return b.create<arith::FPToSIOp>(toType, operand); |
78 | } |
79 | // Cast index operands directly to the int type. |
80 | if (operand.getType().isIndex()) |
81 | return b.create<arith::IndexCastOp>(toType, operand); |
82 | if (auto fromIntType = dyn_cast<IntegerType>(operand.getType())) { |
83 | // Either extend or truncate. |
84 | if (toType.getWidth() > fromIntType.getWidth()) { |
85 | if (isUnsigned) |
86 | return b.create<arith::ExtUIOp>(toType, operand); |
87 | return b.create<arith::ExtSIOp>(toType, operand); |
88 | } |
89 | if (toType.getWidth() < fromIntType.getWidth()) |
90 | return b.create<arith::TruncIOp>(toType, operand); |
91 | return operand; |
92 | } |
93 | |
94 | return {}; |
95 | } |
96 | |
97 | static Value convertScalarToFpDtype(ImplicitLocOpBuilder &b, Value operand, |
98 | FloatType toType, bool isUnsigned) { |
99 | // If operand is integer, cast directly to the float type. |
100 | // Note that it is unclear how to cast from BF16<->FP16. |
101 | if (isa<IntegerType>(Val: operand.getType())) { |
102 | if (isUnsigned) |
103 | return b.create<arith::UIToFPOp>(toType, operand); |
104 | return b.create<arith::SIToFPOp>(toType, operand); |
105 | } |
106 | if (auto fromFpTy = dyn_cast<FloatType>(Val: operand.getType())) { |
107 | if (toType.getWidth() > fromFpTy.getWidth()) |
108 | return b.create<arith::ExtFOp>(toType, operand); |
109 | if (toType.getWidth() < fromFpTy.getWidth()) |
110 | return b.create<arith::TruncFOp>(toType, operand); |
111 | return operand; |
112 | } |
113 | |
114 | return {}; |
115 | } |
116 | |
117 | static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand, |
118 | ComplexType targetType, |
119 | bool isUnsigned) { |
120 | if (auto fromComplexType = dyn_cast<ComplexType>(operand.getType())) { |
121 | if (isa<FloatType>(targetType.getElementType()) && |
122 | isa<FloatType>(fromComplexType.getElementType())) { |
123 | Value real = b.create<complex::ReOp>(operand); |
124 | Value imag = b.create<complex::ImOp>(operand); |
125 | Type targetETy = targetType.getElementType(); |
126 | if (targetType.getElementType().getIntOrFloatBitWidth() < |
127 | fromComplexType.getElementType().getIntOrFloatBitWidth()) { |
128 | real = b.create<arith::TruncFOp>(targetETy, real); |
129 | imag = b.create<arith::TruncFOp>(targetETy, imag); |
130 | } else { |
131 | real = b.create<arith::ExtFOp>(targetETy, real); |
132 | imag = b.create<arith::ExtFOp>(targetETy, imag); |
133 | } |
134 | return b.create<complex::CreateOp>(targetType, real, imag); |
135 | } |
136 | } |
137 | |
138 | if (dyn_cast<FloatType>(Val: operand.getType())) { |
139 | FloatType toFpTy = cast<FloatType>(targetType.getElementType()); |
140 | auto toBitwidth = toFpTy.getIntOrFloatBitWidth(); |
141 | Value from = operand; |
142 | if (from.getType().getIntOrFloatBitWidth() < toBitwidth) { |
143 | from = b.create<arith::ExtFOp>(toFpTy, from); |
144 | } |
145 | if (from.getType().getIntOrFloatBitWidth() > toBitwidth) { |
146 | from = b.create<arith::TruncFOp>(toFpTy, from); |
147 | } |
148 | Value zero = b.create<mlir::arith::ConstantFloatOp>( |
149 | args: mlir::APFloat(toFpTy.getFloatSemantics(), 0), args&: toFpTy); |
150 | return b.create<complex::CreateOp>(targetType, from, zero); |
151 | } |
152 | |
153 | if (dyn_cast<IntegerType>(operand.getType())) { |
154 | FloatType toFpTy = cast<FloatType>(targetType.getElementType()); |
155 | Value from = operand; |
156 | if (isUnsigned) { |
157 | from = b.create<arith::UIToFPOp>(toFpTy, from); |
158 | } else { |
159 | from = b.create<arith::SIToFPOp>(toFpTy, from); |
160 | } |
161 | Value zero = b.create<mlir::arith::ConstantFloatOp>( |
162 | args: mlir::APFloat(toFpTy.getFloatSemantics(), 0), args&: toFpTy); |
163 | return b.create<complex::CreateOp>(targetType, from, zero); |
164 | } |
165 | |
166 | return {}; |
167 | } |
168 | |
169 | Value mlir::convertScalarToDtype(OpBuilder &b, Location loc, Value operand, |
170 | Type toType, bool isUnsignedCast) { |
171 | if (operand.getType() == toType) |
172 | return operand; |
173 | ImplicitLocOpBuilder ib(loc, b); |
174 | Value result; |
175 | if (auto intTy = dyn_cast<IntegerType>(toType)) { |
176 | result = convertScalarToIntDtype(ib, operand, intTy, isUnsignedCast); |
177 | } else if (auto floatTy = dyn_cast<FloatType>(Val&: toType)) { |
178 | result = convertScalarToFpDtype(b&: ib, operand, toType: floatTy, isUnsigned: isUnsignedCast); |
179 | } else if (auto complexTy = dyn_cast<ComplexType>(toType)) { |
180 | result = |
181 | convertScalarToComplexDtype(ib, operand, complexTy, isUnsignedCast); |
182 | } |
183 | |
184 | if (result) |
185 | return result; |
186 | |
187 | emitWarning(loc) << "could not cast operand of type " << operand.getType() |
188 | << " to " << toType; |
189 | return operand; |
190 | } |
191 | |
192 | SmallVector<Value> |
193 | mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, |
194 | ArrayRef<OpFoldResult> valueOrAttrVec) { |
195 | return llvm::to_vector<4>( |
196 | Range: llvm::map_range(C&: valueOrAttrVec, F: [&](OpFoldResult value) -> Value { |
197 | return getValueOrCreateConstantIndexOp(b, loc, ofr: value); |
198 | })); |
199 | } |
200 | |
201 | Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc, |
202 | Type type, const APInt &value) { |
203 | TypedAttr attr; |
204 | if (isa<IntegerType>(Val: type)) { |
205 | attr = builder.getIntegerAttr(type, value); |
206 | } else { |
207 | auto vecTy = cast<ShapedType>(type); |
208 | attr = SplatElementsAttr::get(vecTy, value); |
209 | } |
210 | |
211 | return builder.create<arith::ConstantOp>(loc, attr); |
212 | } |
213 | |
214 | Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc, |
215 | Type type, int64_t value) { |
216 | unsigned elementBitWidth = 0; |
217 | if (auto intTy = dyn_cast<IntegerType>(type)) |
218 | elementBitWidth = intTy.getWidth(); |
219 | else |
220 | elementBitWidth = cast<ShapedType>(type).getElementTypeBitWidth(); |
221 | |
222 | return createScalarOrSplatConstant(builder, loc, type, |
223 | value: APInt(elementBitWidth, value)); |
224 | } |
225 | |
226 | Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc, |
227 | Type type, const APFloat &value) { |
228 | if (isa<FloatType>(type)) |
229 | return builder.createOrFold<arith::ConstantOp>( |
230 | loc, type, builder.getFloatAttr(type, value)); |
231 | TypedAttr splat = SplatElementsAttr::get(cast<ShapedType>(type), value); |
232 | return builder.createOrFold<arith::ConstantOp>(loc, type, splat); |
233 | } |
234 | |
235 | Value ArithBuilder::_and(Value lhs, Value rhs) { |
236 | return b.create<arith::AndIOp>(loc, lhs, rhs); |
237 | } |
238 | Value ArithBuilder::add(Value lhs, Value rhs) { |
239 | if (isa<FloatType>(lhs.getType())) |
240 | return b.create<arith::AddFOp>(loc, lhs, rhs); |
241 | return b.create<arith::AddIOp>(loc, lhs, rhs); |
242 | } |
243 | Value ArithBuilder::sub(Value lhs, Value rhs) { |
244 | if (isa<FloatType>(lhs.getType())) |
245 | return b.create<arith::SubFOp>(loc, lhs, rhs); |
246 | return b.create<arith::SubIOp>(loc, lhs, rhs); |
247 | } |
248 | Value ArithBuilder::mul(Value lhs, Value rhs) { |
249 | if (isa<FloatType>(lhs.getType())) |
250 | return b.create<arith::MulFOp>(loc, lhs, rhs); |
251 | return b.create<arith::MulIOp>(loc, lhs, rhs); |
252 | } |
253 | Value ArithBuilder::sgt(Value lhs, Value rhs) { |
254 | if (isa<FloatType>(lhs.getType())) |
255 | return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs); |
256 | return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs); |
257 | } |
258 | Value ArithBuilder::slt(Value lhs, Value rhs) { |
259 | if (isa<FloatType>(lhs.getType())) |
260 | return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs); |
261 | return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs); |
262 | } |
263 | Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) { |
264 | return b.create<arith::SelectOp>(loc, cmp, lhs, rhs); |
265 | } |
266 | |
267 | namespace mlir::arith { |
268 | |
269 | Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values) { |
270 | return createProduct(builder, loc, values, resultType: values.front().getType()); |
271 | } |
272 | |
273 | Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values, |
274 | Type resultType) { |
275 | Value one = builder.create<ConstantOp>(loc, resultType, |
276 | builder.getOneAttr(resultType)); |
277 | ArithBuilder arithBuilder(builder, loc); |
278 | return std::accumulate( |
279 | first: values.begin(), last: values.end(), init: one, |
280 | binary_op: [&arithBuilder](Value acc, Value v) { return arithBuilder.mul(lhs: acc, rhs: v); }); |
281 | } |
282 | |
283 | } // namespace mlir::arith |
284 | |