1 | //===- Utils.cpp - Utilities to support the Arith dialect -----------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements utilities for the Arith dialect. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
15 | #include "mlir/Dialect/Complex/IR/Complex.h" |
16 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
17 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
18 | #include "llvm/ADT/SmallBitVector.h" |
19 | #include <numeric> |
20 | |
21 | using namespace mlir; |
22 | |
23 | std::optional<SmallVector<OpFoldResult>> |
24 | mlir::inferExpandShapeOutputShape(OpBuilder &b, Location loc, |
25 | ShapedType expandedType, |
26 | ArrayRef<ReassociationIndices> reassociation, |
27 | ArrayRef<OpFoldResult> inputShape) { |
28 | |
29 | SmallVector<Value> outputShapeValues; |
30 | SmallVector<int64_t> outputShapeInts; |
31 | // For zero-rank inputs, all dims in result shape are unit extent. |
32 | if (inputShape.empty()) { |
33 | outputShapeInts.resize(expandedType.getRank(), 1); |
34 | return getMixedValues(staticValues: outputShapeInts, dynamicValues: outputShapeValues, b); |
35 | } |
36 | |
37 | // Check for all static shapes. |
38 | if (expandedType.hasStaticShape()) { |
39 | ArrayRef<int64_t> staticShape = expandedType.getShape(); |
40 | outputShapeInts.assign(in_start: staticShape.begin(), in_end: staticShape.end()); |
41 | return getMixedValues(staticValues: outputShapeInts, dynamicValues: outputShapeValues, b); |
42 | } |
43 | |
44 | outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic); |
45 | for (const auto &it : llvm::enumerate(First&: reassociation)) { |
46 | ReassociationIndices indexGroup = it.value(); |
47 | |
48 | int64_t indexGroupStaticSizesProductInt = 1; |
49 | bool foundDynamicShape = false; |
50 | for (int64_t index : indexGroup) { |
51 | int64_t outputDimSize = expandedType.getDimSize(index); |
52 | // Cannot infer expanded shape with multiple dynamic dims in the |
53 | // same reassociation group! |
54 | if (ShapedType::isDynamic(outputDimSize)) { |
55 | if (foundDynamicShape) |
56 | return std::nullopt; |
57 | foundDynamicShape = true; |
58 | } else { |
59 | outputShapeInts[index] = outputDimSize; |
60 | indexGroupStaticSizesProductInt *= outputDimSize; |
61 | } |
62 | } |
63 | if (!foundDynamicShape) |
64 | continue; |
65 | |
66 | int64_t inputIndex = it.index(); |
67 | // Call get<Value>() under the assumption that we're not casting |
68 | // dynamism. |
69 | Value indexGroupSize = cast<Value>(Val: inputShape[inputIndex]); |
70 | Value indexGroupStaticSizesProduct = |
71 | b.create<arith::ConstantIndexOp>(location: loc, args&: indexGroupStaticSizesProductInt); |
72 | Value dynamicDimSize = b.createOrFold<arith::DivSIOp>( |
73 | loc, indexGroupSize, indexGroupStaticSizesProduct); |
74 | outputShapeValues.push_back(Elt: dynamicDimSize); |
75 | } |
76 | |
77 | if ((int64_t)outputShapeValues.size() != |
78 | llvm::count(outputShapeInts, ShapedType::kDynamic)) |
79 | return std::nullopt; |
80 | |
81 | return getMixedValues(staticValues: outputShapeInts, dynamicValues: outputShapeValues, b); |
82 | } |
83 | |
84 | /// Matches a ConstantIndexOp. |
85 | /// TODO: This should probably just be a general matcher that uses matchConstant |
86 | /// and checks the operation for an index type. |
87 | detail::op_matcher<arith::ConstantIndexOp> mlir::matchConstantIndex() { |
88 | return detail::op_matcher<arith::ConstantIndexOp>(); |
89 | } |
90 | |
91 | llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank, |
92 | ArrayRef<int64_t> shape) { |
93 | llvm::SmallBitVector dimsToProject(shape.size()); |
94 | for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) { |
95 | if (shape[pos] == 1) { |
96 | dimsToProject.set(pos); |
97 | --rank; |
98 | } |
99 | } |
100 | return dimsToProject; |
101 | } |
102 | |
103 | Value mlir::getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, |
104 | OpFoldResult ofr) { |
105 | if (auto value = dyn_cast_if_present<Value>(Val&: ofr)) |
106 | return value; |
107 | auto attr = cast<IntegerAttr>(cast<Attribute>(Val&: ofr)); |
108 | return b.create<arith::ConstantOp>( |
109 | loc, b.getIntegerAttr(attr.getType(), attr.getValue().getSExtValue())); |
110 | } |
111 | |
112 | Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, |
113 | OpFoldResult ofr) { |
114 | if (auto value = dyn_cast_if_present<Value>(Val&: ofr)) |
115 | return value; |
116 | auto attr = cast<IntegerAttr>(cast<Attribute>(Val&: ofr)); |
117 | return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue()); |
118 | } |
119 | |
120 | Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, |
121 | Type targetType, Value value) { |
122 | if (targetType == value.getType()) |
123 | return value; |
124 | |
125 | bool targetIsIndex = targetType.isIndex(); |
126 | bool valueIsIndex = value.getType().isIndex(); |
127 | if (targetIsIndex ^ valueIsIndex) |
128 | return b.create<arith::IndexCastOp>(loc, targetType, value); |
129 | |
130 | auto targetIntegerType = dyn_cast<IntegerType>(targetType); |
131 | auto valueIntegerType = dyn_cast<IntegerType>(value.getType()); |
132 | assert(targetIntegerType && valueIntegerType && |
133 | "unexpected cast between types other than integers and index"); |
134 | assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness()); |
135 | |
136 | if (targetIntegerType.getWidth() > valueIntegerType.getWidth()) |
137 | return b.create<arith::ExtSIOp>(loc, targetIntegerType, value); |
138 | return b.create<arith::TruncIOp>(loc, targetIntegerType, value); |
139 | } |
140 | |
141 | static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand, |
142 | IntegerType toType, bool isUnsigned) { |
143 | // If operand is floating point, cast directly to the int type. |
144 | if (isa<FloatType>(Val: operand.getType())) { |
145 | if (isUnsigned) |
146 | return b.create<arith::FPToUIOp>(toType, operand); |
147 | return b.create<arith::FPToSIOp>(toType, operand); |
148 | } |
149 | // Cast index operands directly to the int type. |
150 | if (operand.getType().isIndex()) |
151 | return b.create<arith::IndexCastOp>(toType, operand); |
152 | if (auto fromIntType = dyn_cast<IntegerType>(operand.getType())) { |
153 | // Either extend or truncate. |
154 | if (toType.getWidth() > fromIntType.getWidth()) { |
155 | if (isUnsigned) |
156 | return b.create<arith::ExtUIOp>(toType, operand); |
157 | return b.create<arith::ExtSIOp>(toType, operand); |
158 | } |
159 | if (toType.getWidth() < fromIntType.getWidth()) |
160 | return b.create<arith::TruncIOp>(toType, operand); |
161 | return operand; |
162 | } |
163 | |
164 | return {}; |
165 | } |
166 | |
167 | static Value convertScalarToFpDtype(ImplicitLocOpBuilder &b, Value operand, |
168 | FloatType toType, bool isUnsigned) { |
169 | // If operand is integer, cast directly to the float type. |
170 | // Note that it is unclear how to cast from BF16<->FP16. |
171 | if (isa<IntegerType>(Val: operand.getType())) { |
172 | if (isUnsigned) |
173 | return b.create<arith::UIToFPOp>(toType, operand); |
174 | return b.create<arith::SIToFPOp>(toType, operand); |
175 | } |
176 | if (auto fromFpTy = dyn_cast<FloatType>(operand.getType())) { |
177 | if (toType.getWidth() > fromFpTy.getWidth()) |
178 | return b.create<arith::ExtFOp>(toType, operand); |
179 | if (toType.getWidth() < fromFpTy.getWidth()) |
180 | return b.create<arith::TruncFOp>(toType, operand); |
181 | return operand; |
182 | } |
183 | |
184 | return {}; |
185 | } |
186 | |
187 | static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand, |
188 | ComplexType targetType, |
189 | bool isUnsigned) { |
190 | if (auto fromComplexType = dyn_cast<ComplexType>(operand.getType())) { |
191 | if (isa<FloatType>(targetType.getElementType()) && |
192 | isa<FloatType>(fromComplexType.getElementType())) { |
193 | Value real = b.create<complex::ReOp>(operand); |
194 | Value imag = b.create<complex::ImOp>(operand); |
195 | Type targetETy = targetType.getElementType(); |
196 | if (targetType.getElementType().getIntOrFloatBitWidth() < |
197 | fromComplexType.getElementType().getIntOrFloatBitWidth()) { |
198 | real = b.create<arith::TruncFOp>(targetETy, real); |
199 | imag = b.create<arith::TruncFOp>(targetETy, imag); |
200 | } else { |
201 | real = b.create<arith::ExtFOp>(targetETy, real); |
202 | imag = b.create<arith::ExtFOp>(targetETy, imag); |
203 | } |
204 | return b.create<complex::CreateOp>(targetType, real, imag); |
205 | } |
206 | } |
207 | |
208 | if (isa<FloatType>(Val: operand.getType())) { |
209 | FloatType toFpTy = cast<FloatType>(targetType.getElementType()); |
210 | auto toBitwidth = toFpTy.getIntOrFloatBitWidth(); |
211 | Value from = operand; |
212 | if (from.getType().getIntOrFloatBitWidth() < toBitwidth) { |
213 | from = b.create<arith::ExtFOp>(toFpTy, from); |
214 | } |
215 | if (from.getType().getIntOrFloatBitWidth() > toBitwidth) { |
216 | from = b.create<arith::TruncFOp>(toFpTy, from); |
217 | } |
218 | Value zero = b.create<mlir::arith::ConstantFloatOp>( |
219 | mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy); |
220 | return b.create<complex::CreateOp>(targetType, from, zero); |
221 | } |
222 | |
223 | if (isa<IntegerType>(Val: operand.getType())) { |
224 | FloatType toFpTy = cast<FloatType>(targetType.getElementType()); |
225 | Value from = operand; |
226 | if (isUnsigned) { |
227 | from = b.create<arith::UIToFPOp>(toFpTy, from); |
228 | } else { |
229 | from = b.create<arith::SIToFPOp>(toFpTy, from); |
230 | } |
231 | Value zero = b.create<mlir::arith::ConstantFloatOp>( |
232 | mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy); |
233 | return b.create<complex::CreateOp>(targetType, from, zero); |
234 | } |
235 | |
236 | return {}; |
237 | } |
238 | |
239 | Value mlir::convertScalarToDtype(OpBuilder &b, Location loc, Value operand, |
240 | Type toType, bool isUnsignedCast) { |
241 | if (operand.getType() == toType) |
242 | return operand; |
243 | ImplicitLocOpBuilder ib(loc, b); |
244 | Value result; |
245 | if (auto intTy = dyn_cast<IntegerType>(toType)) { |
246 | result = convertScalarToIntDtype(ib, operand, intTy, isUnsignedCast); |
247 | } else if (auto floatTy = dyn_cast<FloatType>(toType)) { |
248 | result = convertScalarToFpDtype(ib, operand, floatTy, isUnsignedCast); |
249 | } else if (auto complexTy = dyn_cast<ComplexType>(toType)) { |
250 | result = |
251 | convertScalarToComplexDtype(ib, operand, complexTy, isUnsignedCast); |
252 | } |
253 | |
254 | if (result) |
255 | return result; |
256 | |
257 | emitWarning(loc) << "could not cast operand of type "<< operand.getType() |
258 | << " to "<< toType; |
259 | return operand; |
260 | } |
261 | |
262 | SmallVector<Value> |
263 | mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, |
264 | ArrayRef<OpFoldResult> valueOrAttrVec) { |
265 | return llvm::to_vector<4>( |
266 | Range: llvm::map_range(C&: valueOrAttrVec, F: [&](OpFoldResult value) -> Value { |
267 | return getValueOrCreateConstantIndexOp(b, loc, ofr: value); |
268 | })); |
269 | } |
270 | |
271 | Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc, |
272 | Type type, const APInt &value) { |
273 | TypedAttr attr; |
274 | if (isa<IntegerType>(Val: type)) { |
275 | attr = builder.getIntegerAttr(type, value); |
276 | } else { |
277 | auto vecTy = cast<ShapedType>(type); |
278 | attr = SplatElementsAttr::get(vecTy, value); |
279 | } |
280 | |
281 | return builder.create<arith::ConstantOp>(loc, attr); |
282 | } |
283 | |
284 | Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc, |
285 | Type type, int64_t value) { |
286 | unsigned elementBitWidth = 0; |
287 | if (auto intTy = dyn_cast<IntegerType>(type)) |
288 | elementBitWidth = intTy.getWidth(); |
289 | else |
290 | elementBitWidth = cast<ShapedType>(type).getElementTypeBitWidth(); |
291 | |
292 | return createScalarOrSplatConstant(builder, loc, type, |
293 | value: APInt(elementBitWidth, value)); |
294 | } |
295 | |
296 | Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc, |
297 | Type type, const APFloat &value) { |
298 | if (isa<FloatType>(type)) |
299 | return builder.createOrFold<arith::ConstantOp>( |
300 | loc, type, builder.getFloatAttr(type, value)); |
301 | TypedAttr splat = SplatElementsAttr::get(cast<ShapedType>(type), value); |
302 | return builder.createOrFold<arith::ConstantOp>(loc, type, splat); |
303 | } |
304 | |
305 | Type mlir::getType(OpFoldResult ofr) { |
306 | if (auto value = dyn_cast_if_present<Value>(Val&: ofr)) |
307 | return value.getType(); |
308 | auto attr = cast<IntegerAttr>(cast<Attribute>(Val&: ofr)); |
309 | return attr.getType(); |
310 | } |
311 | |
312 | Value ArithBuilder::_and(Value lhs, Value rhs) { |
313 | return b.create<arith::AndIOp>(loc, lhs, rhs); |
314 | } |
315 | Value ArithBuilder::add(Value lhs, Value rhs) { |
316 | if (isa<FloatType>(lhs.getType())) |
317 | return b.create<arith::AddFOp>(loc, lhs, rhs); |
318 | return b.create<arith::AddIOp>(loc, lhs, rhs, ovf); |
319 | } |
320 | Value ArithBuilder::sub(Value lhs, Value rhs) { |
321 | if (isa<FloatType>(lhs.getType())) |
322 | return b.create<arith::SubFOp>(loc, lhs, rhs); |
323 | return b.create<arith::SubIOp>(loc, lhs, rhs, ovf); |
324 | } |
325 | Value ArithBuilder::mul(Value lhs, Value rhs) { |
326 | if (isa<FloatType>(lhs.getType())) |
327 | return b.create<arith::MulFOp>(loc, lhs, rhs); |
328 | return b.create<arith::MulIOp>(loc, lhs, rhs, ovf); |
329 | } |
330 | Value ArithBuilder::sgt(Value lhs, Value rhs) { |
331 | if (isa<FloatType>(lhs.getType())) |
332 | return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs); |
333 | return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs); |
334 | } |
335 | Value ArithBuilder::slt(Value lhs, Value rhs) { |
336 | if (isa<FloatType>(lhs.getType())) |
337 | return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs); |
338 | return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs); |
339 | } |
340 | Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) { |
341 | return b.create<arith::SelectOp>(loc, cmp, lhs, rhs); |
342 | } |
343 | |
344 | namespace mlir::arith { |
345 | |
346 | Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values) { |
347 | return createProduct(builder, loc, values, resultType: values.front().getType()); |
348 | } |
349 | |
350 | Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values, |
351 | Type resultType) { |
352 | Value one = builder.create<ConstantOp>(loc, resultType, |
353 | builder.getOneAttr(resultType)); |
354 | ArithBuilder arithBuilder(builder, loc); |
355 | return std::accumulate( |
356 | first: values.begin(), last: values.end(), init: one, |
357 | binary_op: [&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); }); |
358 | } |
359 | |
360 | /// Map strings to float types. |
361 | std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name) { |
362 | Builder b(ctx); |
363 | return llvm::StringSwitch<std::optional<FloatType>>(name) |
364 | .Case("f4E2M1FN", b.getType<Float4E2M1FNType>()) |
365 | .Case("f6E2M3FN", b.getType<Float6E2M3FNType>()) |
366 | .Case("f6E3M2FN", b.getType<Float6E3M2FNType>()) |
367 | .Case("f8E5M2", b.getType<Float8E5M2Type>()) |
368 | .Case("f8E4M3", b.getType<Float8E4M3Type>()) |
369 | .Case("f8E4M3FN", b.getType<Float8E4M3FNType>()) |
370 | .Case("f8E5M2FNUZ", b.getType<Float8E5M2FNUZType>()) |
371 | .Case("f8E4M3FNUZ", b.getType<Float8E4M3FNUZType>()) |
372 | .Case("f8E3M4", b.getType<Float8E3M4Type>()) |
373 | .Case("f8E8M0FNU", b.getType<Float8E8M0FNUType>()) |
374 | .Case("bf16", b.getType<BFloat16Type>()) |
375 | .Case("f16", b.getType<Float16Type>()) |
376 | .Case("f32", b.getType<Float32Type>()) |
377 | .Case("f64", b.getType<Float64Type>()) |
378 | .Case("f80", b.getType<Float80Type>()) |
379 | .Case("f128", b.getType<Float128Type>()) |
380 | .Default(std::nullopt); |
381 | } |
382 | |
383 | } // namespace mlir::arith |
384 |
Definitions
- inferExpandShapeOutputShape
- matchConstantIndex
- getPositionsOfShapeOne
- getValueOrCreateConstantIntOp
- getValueOrCreateConstantIndexOp
- getValueOrCreateCastToIndexLike
- convertScalarToIntDtype
- convertScalarToFpDtype
- convertScalarToComplexDtype
- convertScalarToDtype
- getValueOrCreateConstantIndexOp
- createScalarOrSplatConstant
- createScalarOrSplatConstant
- createScalarOrSplatConstant
- getType
- _and
- add
- sub
- mul
- sgt
- slt
- select
- createProduct
- createProduct
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more