//===- Utils.cpp - Utilities to support the Arith dialect -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements utilities for the Arith dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Arith/Utils/Utils.h"
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/Complex/IR/Complex.h"
16#include "mlir/Dialect/Utils/StaticValueUtils.h"
17#include "mlir/IR/ImplicitLocOpBuilder.h"
18#include "llvm/ADT/SmallBitVector.h"
19#include <numeric>
20
21using namespace mlir;
22
23std::optional<SmallVector<OpFoldResult>>
24mlir::inferExpandShapeOutputShape(OpBuilder &b, Location loc,
25 ShapedType expandedType,
26 ArrayRef<ReassociationIndices> reassociation,
27 ArrayRef<OpFoldResult> inputShape) {
28
29 SmallVector<Value> outputShapeValues;
30 SmallVector<int64_t> outputShapeInts;
31 // For zero-rank inputs, all dims in result shape are unit extent.
32 if (inputShape.empty()) {
33 outputShapeInts.resize(expandedType.getRank(), 1);
34 return getMixedValues(staticValues: outputShapeInts, dynamicValues: outputShapeValues, b);
35 }
36
37 // Check for all static shapes.
38 if (expandedType.hasStaticShape()) {
39 ArrayRef<int64_t> staticShape = expandedType.getShape();
40 outputShapeInts.assign(in_start: staticShape.begin(), in_end: staticShape.end());
41 return getMixedValues(staticValues: outputShapeInts, dynamicValues: outputShapeValues, b);
42 }
43
44 outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic);
45 for (const auto &it : llvm::enumerate(First&: reassociation)) {
46 ReassociationIndices indexGroup = it.value();
47
48 int64_t indexGroupStaticSizesProductInt = 1;
49 bool foundDynamicShape = false;
50 for (int64_t index : indexGroup) {
51 int64_t outputDimSize = expandedType.getDimSize(index);
52 // Cannot infer expanded shape with multiple dynamic dims in the
53 // same reassociation group!
54 if (ShapedType::isDynamic(outputDimSize)) {
55 if (foundDynamicShape)
56 return std::nullopt;
57 foundDynamicShape = true;
58 } else {
59 outputShapeInts[index] = outputDimSize;
60 indexGroupStaticSizesProductInt *= outputDimSize;
61 }
62 }
63 if (!foundDynamicShape)
64 continue;
65
66 int64_t inputIndex = it.index();
67 // Call get<Value>() under the assumption that we're not casting
68 // dynamism.
69 Value indexGroupSize = cast<Value>(Val: inputShape[inputIndex]);
70 Value indexGroupStaticSizesProduct =
71 b.create<arith::ConstantIndexOp>(location: loc, args&: indexGroupStaticSizesProductInt);
72 Value dynamicDimSize = b.createOrFold<arith::DivSIOp>(
73 loc, indexGroupSize, indexGroupStaticSizesProduct);
74 outputShapeValues.push_back(Elt: dynamicDimSize);
75 }
76
77 if ((int64_t)outputShapeValues.size() !=
78 llvm::count(outputShapeInts, ShapedType::kDynamic))
79 return std::nullopt;
80
81 return getMixedValues(staticValues: outputShapeInts, dynamicValues: outputShapeValues, b);
82}
83
84/// Matches a ConstantIndexOp.
85/// TODO: This should probably just be a general matcher that uses matchConstant
86/// and checks the operation for an index type.
87detail::op_matcher<arith::ConstantIndexOp> mlir::matchConstantIndex() {
88 return detail::op_matcher<arith::ConstantIndexOp>();
89}
90
91llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
92 ArrayRef<int64_t> shape) {
93 llvm::SmallBitVector dimsToProject(shape.size());
94 for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) {
95 if (shape[pos] == 1) {
96 dimsToProject.set(pos);
97 --rank;
98 }
99 }
100 return dimsToProject;
101}
102
103Value mlir::getValueOrCreateConstantIntOp(OpBuilder &b, Location loc,
104 OpFoldResult ofr) {
105 if (auto value = dyn_cast_if_present<Value>(Val&: ofr))
106 return value;
107 auto attr = cast<IntegerAttr>(cast<Attribute>(Val&: ofr));
108 return b.create<arith::ConstantOp>(
109 loc, b.getIntegerAttr(attr.getType(), attr.getValue().getSExtValue()));
110}
111
112Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
113 OpFoldResult ofr) {
114 if (auto value = dyn_cast_if_present<Value>(Val&: ofr))
115 return value;
116 auto attr = cast<IntegerAttr>(cast<Attribute>(Val&: ofr));
117 return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
118}
119
120Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
121 Type targetType, Value value) {
122 if (targetType == value.getType())
123 return value;
124
125 bool targetIsIndex = targetType.isIndex();
126 bool valueIsIndex = value.getType().isIndex();
127 if (targetIsIndex ^ valueIsIndex)
128 return b.create<arith::IndexCastOp>(loc, targetType, value);
129
130 auto targetIntegerType = dyn_cast<IntegerType>(targetType);
131 auto valueIntegerType = dyn_cast<IntegerType>(value.getType());
132 assert(targetIntegerType && valueIntegerType &&
133 "unexpected cast between types other than integers and index");
134 assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
135
136 if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
137 return b.create<arith::ExtSIOp>(loc, targetIntegerType, value);
138 return b.create<arith::TruncIOp>(loc, targetIntegerType, value);
139}
140
141static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand,
142 IntegerType toType, bool isUnsigned) {
143 // If operand is floating point, cast directly to the int type.
144 if (isa<FloatType>(Val: operand.getType())) {
145 if (isUnsigned)
146 return b.create<arith::FPToUIOp>(toType, operand);
147 return b.create<arith::FPToSIOp>(toType, operand);
148 }
149 // Cast index operands directly to the int type.
150 if (operand.getType().isIndex())
151 return b.create<arith::IndexCastOp>(toType, operand);
152 if (auto fromIntType = dyn_cast<IntegerType>(operand.getType())) {
153 // Either extend or truncate.
154 if (toType.getWidth() > fromIntType.getWidth()) {
155 if (isUnsigned)
156 return b.create<arith::ExtUIOp>(toType, operand);
157 return b.create<arith::ExtSIOp>(toType, operand);
158 }
159 if (toType.getWidth() < fromIntType.getWidth())
160 return b.create<arith::TruncIOp>(toType, operand);
161 return operand;
162 }
163
164 return {};
165}
166
167static Value convertScalarToFpDtype(ImplicitLocOpBuilder &b, Value operand,
168 FloatType toType, bool isUnsigned) {
169 // If operand is integer, cast directly to the float type.
170 // Note that it is unclear how to cast from BF16<->FP16.
171 if (isa<IntegerType>(Val: operand.getType())) {
172 if (isUnsigned)
173 return b.create<arith::UIToFPOp>(toType, operand);
174 return b.create<arith::SIToFPOp>(toType, operand);
175 }
176 if (auto fromFpTy = dyn_cast<FloatType>(operand.getType())) {
177 if (toType.getWidth() > fromFpTy.getWidth())
178 return b.create<arith::ExtFOp>(toType, operand);
179 if (toType.getWidth() < fromFpTy.getWidth())
180 return b.create<arith::TruncFOp>(toType, operand);
181 return operand;
182 }
183
184 return {};
185}
186
187static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand,
188 ComplexType targetType,
189 bool isUnsigned) {
190 if (auto fromComplexType = dyn_cast<ComplexType>(operand.getType())) {
191 if (isa<FloatType>(targetType.getElementType()) &&
192 isa<FloatType>(fromComplexType.getElementType())) {
193 Value real = b.create<complex::ReOp>(operand);
194 Value imag = b.create<complex::ImOp>(operand);
195 Type targetETy = targetType.getElementType();
196 if (targetType.getElementType().getIntOrFloatBitWidth() <
197 fromComplexType.getElementType().getIntOrFloatBitWidth()) {
198 real = b.create<arith::TruncFOp>(targetETy, real);
199 imag = b.create<arith::TruncFOp>(targetETy, imag);
200 } else {
201 real = b.create<arith::ExtFOp>(targetETy, real);
202 imag = b.create<arith::ExtFOp>(targetETy, imag);
203 }
204 return b.create<complex::CreateOp>(targetType, real, imag);
205 }
206 }
207
208 if (isa<FloatType>(Val: operand.getType())) {
209 FloatType toFpTy = cast<FloatType>(targetType.getElementType());
210 auto toBitwidth = toFpTy.getIntOrFloatBitWidth();
211 Value from = operand;
212 if (from.getType().getIntOrFloatBitWidth() < toBitwidth) {
213 from = b.create<arith::ExtFOp>(toFpTy, from);
214 }
215 if (from.getType().getIntOrFloatBitWidth() > toBitwidth) {
216 from = b.create<arith::TruncFOp>(toFpTy, from);
217 }
218 Value zero = b.create<mlir::arith::ConstantFloatOp>(
219 mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy);
220 return b.create<complex::CreateOp>(targetType, from, zero);
221 }
222
223 if (isa<IntegerType>(Val: operand.getType())) {
224 FloatType toFpTy = cast<FloatType>(targetType.getElementType());
225 Value from = operand;
226 if (isUnsigned) {
227 from = b.create<arith::UIToFPOp>(toFpTy, from);
228 } else {
229 from = b.create<arith::SIToFPOp>(toFpTy, from);
230 }
231 Value zero = b.create<mlir::arith::ConstantFloatOp>(
232 mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy);
233 return b.create<complex::CreateOp>(targetType, from, zero);
234 }
235
236 return {};
237}
238
239Value mlir::convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
240 Type toType, bool isUnsignedCast) {
241 if (operand.getType() == toType)
242 return operand;
243 ImplicitLocOpBuilder ib(loc, b);
244 Value result;
245 if (auto intTy = dyn_cast<IntegerType>(toType)) {
246 result = convertScalarToIntDtype(ib, operand, intTy, isUnsignedCast);
247 } else if (auto floatTy = dyn_cast<FloatType>(toType)) {
248 result = convertScalarToFpDtype(ib, operand, floatTy, isUnsignedCast);
249 } else if (auto complexTy = dyn_cast<ComplexType>(toType)) {
250 result =
251 convertScalarToComplexDtype(ib, operand, complexTy, isUnsignedCast);
252 }
253
254 if (result)
255 return result;
256
257 emitWarning(loc) << "could not cast operand of type " << operand.getType()
258 << " to " << toType;
259 return operand;
260}
261
262SmallVector<Value>
263mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
264 ArrayRef<OpFoldResult> valueOrAttrVec) {
265 return llvm::to_vector<4>(
266 Range: llvm::map_range(C&: valueOrAttrVec, F: [&](OpFoldResult value) -> Value {
267 return getValueOrCreateConstantIndexOp(b, loc, ofr: value);
268 }));
269}
270
271Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
272 Type type, const APInt &value) {
273 TypedAttr attr;
274 if (isa<IntegerType>(Val: type)) {
275 attr = builder.getIntegerAttr(type, value);
276 } else {
277 auto vecTy = cast<ShapedType>(type);
278 attr = SplatElementsAttr::get(vecTy, value);
279 }
280
281 return builder.create<arith::ConstantOp>(loc, attr);
282}
283
284Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
285 Type type, int64_t value) {
286 unsigned elementBitWidth = 0;
287 if (auto intTy = dyn_cast<IntegerType>(type))
288 elementBitWidth = intTy.getWidth();
289 else
290 elementBitWidth = cast<ShapedType>(type).getElementTypeBitWidth();
291
292 return createScalarOrSplatConstant(builder, loc, type,
293 value: APInt(elementBitWidth, value));
294}
295
296Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
297 Type type, const APFloat &value) {
298 if (isa<FloatType>(type))
299 return builder.createOrFold<arith::ConstantOp>(
300 loc, type, builder.getFloatAttr(type, value));
301 TypedAttr splat = SplatElementsAttr::get(cast<ShapedType>(type), value);
302 return builder.createOrFold<arith::ConstantOp>(loc, type, splat);
303}
304
305Type mlir::getType(OpFoldResult ofr) {
306 if (auto value = dyn_cast_if_present<Value>(Val&: ofr))
307 return value.getType();
308 auto attr = cast<IntegerAttr>(cast<Attribute>(Val&: ofr));
309 return attr.getType();
310}
311
312Value ArithBuilder::_and(Value lhs, Value rhs) {
313 return b.create<arith::AndIOp>(loc, lhs, rhs);
314}
315Value ArithBuilder::add(Value lhs, Value rhs) {
316 if (isa<FloatType>(lhs.getType()))
317 return b.create<arith::AddFOp>(loc, lhs, rhs);
318 return b.create<arith::AddIOp>(loc, lhs, rhs, ovf);
319}
320Value ArithBuilder::sub(Value lhs, Value rhs) {
321 if (isa<FloatType>(lhs.getType()))
322 return b.create<arith::SubFOp>(loc, lhs, rhs);
323 return b.create<arith::SubIOp>(loc, lhs, rhs, ovf);
324}
325Value ArithBuilder::mul(Value lhs, Value rhs) {
326 if (isa<FloatType>(lhs.getType()))
327 return b.create<arith::MulFOp>(loc, lhs, rhs);
328 return b.create<arith::MulIOp>(loc, lhs, rhs, ovf);
329}
330Value ArithBuilder::sgt(Value lhs, Value rhs) {
331 if (isa<FloatType>(lhs.getType()))
332 return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs);
333 return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs);
334}
335Value ArithBuilder::slt(Value lhs, Value rhs) {
336 if (isa<FloatType>(lhs.getType()))
337 return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs);
338 return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs);
339}
340Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) {
341 return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);
342}
343
344namespace mlir::arith {
345
346Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values) {
347 return createProduct(builder, loc, values, resultType: values.front().getType());
348}
349
350Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
351 Type resultType) {
352 Value one = builder.create<ConstantOp>(loc, resultType,
353 builder.getOneAttr(resultType));
354 ArithBuilder arithBuilder(builder, loc);
355 return std::accumulate(
356 first: values.begin(), last: values.end(), init: one,
357 binary_op: [&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); });
358}
359
360/// Map strings to float types.
361std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name) {
362 Builder b(ctx);
363 return llvm::StringSwitch<std::optional<FloatType>>(name)
364 .Case("f4E2M1FN", b.getType<Float4E2M1FNType>())
365 .Case("f6E2M3FN", b.getType<Float6E2M3FNType>())
366 .Case("f6E3M2FN", b.getType<Float6E3M2FNType>())
367 .Case("f8E5M2", b.getType<Float8E5M2Type>())
368 .Case("f8E4M3", b.getType<Float8E4M3Type>())
369 .Case("f8E4M3FN", b.getType<Float8E4M3FNType>())
370 .Case("f8E5M2FNUZ", b.getType<Float8E5M2FNUZType>())
371 .Case("f8E4M3FNUZ", b.getType<Float8E4M3FNUZType>())
372 .Case("f8E3M4", b.getType<Float8E3M4Type>())
373 .Case("f8E8M0FNU", b.getType<Float8E8M0FNUType>())
374 .Case("bf16", b.getType<BFloat16Type>())
375 .Case("f16", b.getType<Float16Type>())
376 .Case("f32", b.getType<Float32Type>())
377 .Case("f64", b.getType<Float64Type>())
378 .Case("f80", b.getType<Float80Type>())
379 .Case("f128", b.getType<Float128Type>())
380 .Default(std::nullopt);
381}
382
383} // namespace mlir::arith
384

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Dialect/Arith/Utils/Utils.cpp