1//===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Arith/Transforms/Passes.h"
10
11#include "mlir/Dialect/Arith/IR/Arith.h"
12#include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h"
13#include "mlir/Dialect/Arith/Utils/Utils.h"
14#include "mlir/Dialect/Func/IR/FuncOps.h"
15#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
16#include "mlir/Dialect/Vector/IR/VectorOps.h"
17#include "mlir/IR/BuiltinTypes.h"
18#include "mlir/IR/TypeUtilities.h"
19#include "mlir/Transforms/DialectConversion.h"
20#include "llvm/ADT/APFloat.h"
21#include "llvm/ADT/APInt.h"
22#include "llvm/Support/FormatVariadic.h"
23#include "llvm/Support/MathExtras.h"
24#include <cassert>
25
26namespace mlir::arith {
27#define GEN_PASS_DEF_ARITHEMULATEWIDEINT
28#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
29} // namespace mlir::arith
30
31using namespace mlir;
32
33//===----------------------------------------------------------------------===//
34// Common Helper Functions
35//===----------------------------------------------------------------------===//
36
37/// Returns N bottom and N top bits from `value`, where N = `newBitWidth`.
38/// Treats `value` as a 2*N bits-wide integer.
39/// The bottom bits are returned in the first pair element, while the top bits
40/// in the second one.
41static std::pair<APInt, APInt> getHalves(const APInt &value,
42 unsigned newBitWidth) {
43 APInt low = value.extractBits(numBits: newBitWidth, bitPosition: 0);
44 APInt high = value.extractBits(numBits: newBitWidth, bitPosition: newBitWidth);
45 return {std::move(low), std::move(high)};
46}
47
48/// Returns the type with the last (innermost) dimension reduced to x1.
49/// Scalarizes 1D vector inputs to match how we extract/insert vector values,
50/// e.g.:
51/// - vector<3x2xi16> --> vector<3x1xi16>
52/// - vector<2xi16> --> i16
53static Type reduceInnermostDim(VectorType type) {
54 if (type.getShape().size() == 1)
55 return type.getElementType();
56
57 auto newShape = to_vector(type.getShape());
58 newShape.back() = 1;
59 return VectorType::get(newShape, type.getElementType());
60}
61
62/// Extracts the `input` vector slice with elements at the last dimension offset
63/// by `lastOffset`. Returns a value of vector type with the last dimension
64/// reduced to x1 or fully scalarized, e.g.:
65/// - vector<3x2xi16> --> vector<3x1xi16>
66/// - vector<2xi16> --> i16
67static Value extractLastDimSlice(ConversionPatternRewriter &rewriter,
68 Location loc, Value input,
69 int64_t lastOffset) {
70 ArrayRef<int64_t> shape = cast<VectorType>(input.getType()).getShape();
71 assert(lastOffset < shape.back() && "Offset out of bounds");
72
73 // Scalarize the result in case of 1D vectors.
74 if (shape.size() == 1)
75 return rewriter.create<vector::ExtractOp>(loc, input, lastOffset);
76
77 SmallVector<int64_t> offsets(shape.size(), 0);
78 offsets.back() = lastOffset;
79 auto sizes = llvm::to_vector(Range&: shape);
80 sizes.back() = 1;
81 SmallVector<int64_t> strides(shape.size(), 1);
82
83 return rewriter.create<vector::ExtractStridedSliceOp>(loc, input, offsets,
84 sizes, strides);
85}
86
87/// Extracts two vector slices from the `input` whose type is `vector<...x2T>`,
88/// with the first element at offset 0 and the second element at offset 1.
89static std::pair<Value, Value>
90extractLastDimHalves(ConversionPatternRewriter &rewriter, Location loc,
91 Value input) {
92 return {extractLastDimSlice(rewriter, loc, input, lastOffset: 0),
93 extractLastDimSlice(rewriter, loc, input, lastOffset: 1)};
94}
95
96// Performs a vector shape cast to drop the trailing x1 dimension. If the
97// `input` is a scalar, this is a noop.
98static Value dropTrailingX1Dim(ConversionPatternRewriter &rewriter,
99 Location loc, Value input) {
100 auto vecTy = dyn_cast<VectorType>(input.getType());
101 if (!vecTy)
102 return input;
103
104 // Shape cast to drop the last x1 dimension.
105 ArrayRef<int64_t> shape = vecTy.getShape();
106 assert(shape.size() >= 2 && "Expected vector with at list two dims");
107 assert(shape.back() == 1 && "Expected the last vector dim to be x1");
108
109 auto newVecTy = VectorType::get(shape.drop_back(), vecTy.getElementType());
110 return rewriter.create<vector::ShapeCastOp>(loc, newVecTy, input);
111}
112
113/// Performs a vector shape cast to append an x1 dimension. If the
114/// `input` is a scalar, this is a noop.
115static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc,
116 Value input) {
117 auto vecTy = dyn_cast<VectorType>(input.getType());
118 if (!vecTy)
119 return input;
120
121 // Add a trailing x1 dim.
122 auto newShape = llvm::to_vector(vecTy.getShape());
123 newShape.push_back(1);
124 auto newTy = VectorType::get(newShape, vecTy.getElementType());
125 return rewriter.create<vector::ShapeCastOp>(loc, newTy, input);
126}
127
128/// Inserts the `source` vector slice into the `dest` vector at offset
129/// `lastOffset` in the last dimension. `source` can be a scalar when `dest` is
130/// a 1D vector.
131static Value insertLastDimSlice(ConversionPatternRewriter &rewriter,
132 Location loc, Value source, Value dest,
133 int64_t lastOffset) {
134 ArrayRef<int64_t> shape = cast<VectorType>(dest.getType()).getShape();
135 assert(lastOffset < shape.back() && "Offset out of bounds");
136
137 // Handle scalar source.
138 if (isa<IntegerType>(source.getType()))
139 return rewriter.create<vector::InsertOp>(loc, source, dest, lastOffset);
140
141 SmallVector<int64_t> offsets(shape.size(), 0);
142 offsets.back() = lastOffset;
143 SmallVector<int64_t> strides(shape.size(), 1);
144 return rewriter.create<vector::InsertStridedSliceOp>(loc, source, dest,
145 offsets, strides);
146}
147
148/// Constructs a new vector of type `resultType` by creating a series of
149/// insertions of `resultComponents`, each at the next offset of the last vector
150/// dimension.
151/// When all `resultComponents` are scalars, the result type is `vector<NxT>`;
152/// when `resultComponents` are `vector<...x1xT>`s, the result type is
153/// `vector<...xNxT>`, where `N` is the number of `resultComponents`.
154static Value constructResultVector(ConversionPatternRewriter &rewriter,
155 Location loc, VectorType resultType,
156 ValueRange resultComponents) {
157 llvm::ArrayRef<int64_t> resultShape = resultType.getShape();
158 (void)resultShape;
159 assert(!resultShape.empty() && "Result expected to have dimensions");
160 assert(resultShape.back() == static_cast<int64_t>(resultComponents.size()) &&
161 "Wrong number of result components");
162
163 Value resultVec = createScalarOrSplatConstant(rewriter, loc, resultType, 0);
164 for (auto [i, component] : llvm::enumerate(First&: resultComponents))
165 resultVec = insertLastDimSlice(rewriter, loc, source: component, dest: resultVec, lastOffset: i);
166
167 return resultVec;
168}
169
170namespace {
171//===----------------------------------------------------------------------===//
172// ConvertConstant
173//===----------------------------------------------------------------------===//
174
175struct ConvertConstant final : OpConversionPattern<arith::ConstantOp> {
176 using OpConversionPattern::OpConversionPattern;
177
178 LogicalResult
179 matchAndRewrite(arith::ConstantOp op, OpAdaptor,
180 ConversionPatternRewriter &rewriter) const override {
181 Type oldType = op.getType();
182 auto newType = getTypeConverter()->convertType<VectorType>(oldType);
183 if (!newType)
184 return rewriter.notifyMatchFailure(
185 op, llvm::formatv("unsupported type: {0}", op.getType()));
186
187 unsigned newBitWidth = newType.getElementTypeBitWidth();
188 Attribute oldValue = op.getValueAttr();
189
190 if (auto intAttr = dyn_cast<IntegerAttr>(oldValue)) {
191 auto [low, high] = getHalves(intAttr.getValue(), newBitWidth);
192 auto newAttr = DenseElementsAttr::get(newType, {low, high});
193 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
194 return success();
195 }
196
197 if (auto splatAttr = dyn_cast<SplatElementsAttr>(oldValue)) {
198 auto [low, high] =
199 getHalves(splatAttr.getSplatValue<APInt>(), newBitWidth);
200 int64_t numSplatElems = splatAttr.getNumElements();
201 SmallVector<APInt> values;
202 values.reserve(N: numSplatElems * 2);
203 for (int64_t i = 0; i < numSplatElems; ++i) {
204 values.push_back(low);
205 values.push_back(high);
206 }
207
208 auto attr = DenseElementsAttr::get(newType, values);
209 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr);
210 return success();
211 }
212
213 if (auto elemsAttr = dyn_cast<DenseElementsAttr>(oldValue)) {
214 int64_t numElems = elemsAttr.getNumElements();
215 SmallVector<APInt> values;
216 values.reserve(N: numElems * 2);
217 for (const APInt &origVal : elemsAttr.getValues<APInt>()) {
218 auto [low, high] = getHalves(origVal, newBitWidth);
219 values.push_back(std::move(low));
220 values.push_back(std::move(high));
221 }
222
223 auto attr = DenseElementsAttr::get(newType, values);
224 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr);
225 return success();
226 }
227
228 return rewriter.notifyMatchFailure(op.getLoc(),
229 "unhandled constant attribute");
230 }
231};
232
233//===----------------------------------------------------------------------===//
234// ConvertAddI
235//===----------------------------------------------------------------------===//
236
237struct ConvertAddI final : OpConversionPattern<arith::AddIOp> {
238 using OpConversionPattern::OpConversionPattern;
239
240 LogicalResult
241 matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor,
242 ConversionPatternRewriter &rewriter) const override {
243 Location loc = op->getLoc();
244 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
245 if (!newTy)
246 return rewriter.notifyMatchFailure(
247 loc, llvm::formatv("unsupported type: {0}", op.getType()));
248
249 Type newElemTy = reduceInnermostDim(newTy);
250
251 auto [lhsElem0, lhsElem1] =
252 extractLastDimHalves(rewriter, loc, adaptor.getLhs());
253 auto [rhsElem0, rhsElem1] =
254 extractLastDimHalves(rewriter, loc, adaptor.getRhs());
255
256 auto lowSum =
257 rewriter.create<arith::AddUIExtendedOp>(loc, lhsElem0, rhsElem0);
258 Value overflowVal =
259 rewriter.create<arith::ExtUIOp>(loc, newElemTy, lowSum.getOverflow());
260
261 Value high0 = rewriter.create<arith::AddIOp>(loc, overflowVal, lhsElem1);
262 Value high = rewriter.create<arith::AddIOp>(loc, high0, rhsElem1);
263
264 Value resultVec =
265 constructResultVector(rewriter, loc, newTy, {lowSum.getSum(), high});
266 rewriter.replaceOp(op, resultVec);
267 return success();
268 }
269};
270
271//===----------------------------------------------------------------------===//
272// ConvertBitwiseBinary
273//===----------------------------------------------------------------------===//
274
275/// Conversion pattern template for bitwise binary ops, e.g., `arith.andi`.
276template <typename BinaryOp>
277struct ConvertBitwiseBinary final : OpConversionPattern<BinaryOp> {
278 using OpConversionPattern<BinaryOp>::OpConversionPattern;
279 using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor;
280
281 LogicalResult
282 matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
283 ConversionPatternRewriter &rewriter) const override {
284 Location loc = op->getLoc();
285 auto newTy = this->getTypeConverter()->template convertType<VectorType>(
286 op.getType());
287 if (!newTy)
288 return rewriter.notifyMatchFailure(
289 loc, llvm::formatv("unsupported type: {0}", op.getType()));
290
291 auto [lhsElem0, lhsElem1] =
292 extractLastDimHalves(rewriter, loc, adaptor.getLhs());
293 auto [rhsElem0, rhsElem1] =
294 extractLastDimHalves(rewriter, loc, adaptor.getRhs());
295
296 Value resElem0 = rewriter.create<BinaryOp>(loc, lhsElem0, rhsElem0);
297 Value resElem1 = rewriter.create<BinaryOp>(loc, lhsElem1, rhsElem1);
298 Value resultVec =
299 constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
300 rewriter.replaceOp(op, resultVec);
301 return success();
302 }
303};
304
305//===----------------------------------------------------------------------===//
306// ConvertCmpI
307//===----------------------------------------------------------------------===//
308
309/// Returns the matching unsigned version of the given predicate `pred`, or the
310/// same predicate if `pred` is not a signed.
311static arith::CmpIPredicate toUnsignedPredicate(arith::CmpIPredicate pred) {
312 using P = arith::CmpIPredicate;
313 switch (pred) {
314 case P::sge:
315 return P::uge;
316 case P::sgt:
317 return P::ugt;
318 case P::sle:
319 return P::ule;
320 case P::slt:
321 return P::ult;
322 default:
323 return pred;
324 }
325}
326
327struct ConvertCmpI final : OpConversionPattern<arith::CmpIOp> {
328 using OpConversionPattern::OpConversionPattern;
329
330 LogicalResult
331 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
332 ConversionPatternRewriter &rewriter) const override {
333 Location loc = op->getLoc();
334 auto inputTy =
335 getTypeConverter()->convertType<VectorType>(op.getLhs().getType());
336 if (!inputTy)
337 return rewriter.notifyMatchFailure(
338 loc, llvm::formatv("unsupported type: {0}", op.getType()));
339
340 arith::CmpIPredicate highPred = adaptor.getPredicate();
341 arith::CmpIPredicate lowPred = toUnsignedPredicate(highPred);
342
343 auto [lhsElem0, lhsElem1] =
344 extractLastDimHalves(rewriter, loc, adaptor.getLhs());
345 auto [rhsElem0, rhsElem1] =
346 extractLastDimHalves(rewriter, loc, adaptor.getRhs());
347
348 Value lowCmp =
349 rewriter.create<arith::CmpIOp>(loc, lowPred, lhsElem0, rhsElem0);
350 Value highCmp =
351 rewriter.create<arith::CmpIOp>(loc, highPred, lhsElem1, rhsElem1);
352
353 Value cmpResult{};
354 switch (highPred) {
355 case arith::CmpIPredicate::eq: {
356 cmpResult = rewriter.create<arith::AndIOp>(loc, lowCmp, highCmp);
357 break;
358 }
359 case arith::CmpIPredicate::ne: {
360 cmpResult = rewriter.create<arith::OrIOp>(loc, lowCmp, highCmp);
361 break;
362 }
363 default: {
364 // Handle inequality checks.
365 Value highEq = rewriter.create<arith::CmpIOp>(
366 loc, arith::CmpIPredicate::eq, lhsElem1, rhsElem1);
367 cmpResult =
368 rewriter.create<arith::SelectOp>(loc, highEq, lowCmp, highCmp);
369 break;
370 }
371 }
372
373 assert(cmpResult && "Unhandled case");
374 rewriter.replaceOp(op, dropTrailingX1Dim(rewriter, loc, input: cmpResult));
375 return success();
376 }
377};
378
379//===----------------------------------------------------------------------===//
380// ConvertMulI
381//===----------------------------------------------------------------------===//
382
383struct ConvertMulI final : OpConversionPattern<arith::MulIOp> {
384 using OpConversionPattern::OpConversionPattern;
385
386 LogicalResult
387 matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor,
388 ConversionPatternRewriter &rewriter) const override {
389 Location loc = op->getLoc();
390 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
391 if (!newTy)
392 return rewriter.notifyMatchFailure(
393 loc, llvm::formatv("unsupported type: {0}", op.getType()));
394
395 auto [lhsElem0, lhsElem1] =
396 extractLastDimHalves(rewriter, loc, adaptor.getLhs());
397 auto [rhsElem0, rhsElem1] =
398 extractLastDimHalves(rewriter, loc, adaptor.getRhs());
399
400 // The multiplication algorithm used is the standard (long) multiplication.
401 // Multiplying two i2N integers produces (at most) an i4N result, but
402 // because the calculation of top i2N is not necessary, we omit it.
403 auto mulLowLow =
404 rewriter.create<arith::MulUIExtendedOp>(loc, lhsElem0, rhsElem0);
405 Value mulLowHi = rewriter.create<arith::MulIOp>(loc, lhsElem0, rhsElem1);
406 Value mulHiLow = rewriter.create<arith::MulIOp>(loc, lhsElem1, rhsElem0);
407
408 Value resLow = mulLowLow.getLow();
409 Value resHi =
410 rewriter.create<arith::AddIOp>(loc, mulLowLow.getHigh(), mulLowHi);
411 resHi = rewriter.create<arith::AddIOp>(loc, resHi, mulHiLow);
412
413 Value resultVec =
414 constructResultVector(rewriter, loc, newTy, {resLow, resHi});
415 rewriter.replaceOp(op, resultVec);
416 return success();
417 }
418};
419
420//===----------------------------------------------------------------------===//
421// ConvertExtSI
422//===----------------------------------------------------------------------===//
423
424struct ConvertExtSI final : OpConversionPattern<arith::ExtSIOp> {
425 using OpConversionPattern::OpConversionPattern;
426
427 LogicalResult
428 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
429 ConversionPatternRewriter &rewriter) const override {
430 Location loc = op->getLoc();
431 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
432 if (!newTy)
433 return rewriter.notifyMatchFailure(
434 loc, llvm::formatv("unsupported type: {0}", op.getType()));
435
436 Type newResultComponentTy = reduceInnermostDim(newTy);
437
438 // Sign-extend the input value to determine the low half of the result.
439 // Then, check if the low half is negative, and sign-extend the comparison
440 // result to get the high half.
441 Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn());
442 Value extended = rewriter.createOrFold<arith::ExtSIOp>(
443 loc, newResultComponentTy, newOperand);
444 Value operandZeroCst =
445 createScalarOrSplatConstant(builder&: rewriter, loc, type: newResultComponentTy, value: 0);
446 Value signBit = rewriter.create<arith::CmpIOp>(
447 loc, arith::CmpIPredicate::slt, extended, operandZeroCst);
448 Value signValue =
449 rewriter.create<arith::ExtSIOp>(loc, newResultComponentTy, signBit);
450
451 Value resultVec =
452 constructResultVector(rewriter, loc, newTy, {extended, signValue});
453 rewriter.replaceOp(op, resultVec);
454 return success();
455 }
456};
457
458//===----------------------------------------------------------------------===//
459// ConvertExtUI
460//===----------------------------------------------------------------------===//
461
462struct ConvertExtUI final : OpConversionPattern<arith::ExtUIOp> {
463 using OpConversionPattern::OpConversionPattern;
464
465 LogicalResult
466 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
467 ConversionPatternRewriter &rewriter) const override {
468 Location loc = op->getLoc();
469 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
470 if (!newTy)
471 return rewriter.notifyMatchFailure(
472 loc, llvm::formatv("unsupported type: {0}", op.getType()));
473
474 Type newResultComponentTy = reduceInnermostDim(newTy);
475
476 // Zero-extend the input value to determine the low half of the result.
477 // The high half is always zero.
478 Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn());
479 Value extended = rewriter.createOrFold<arith::ExtUIOp>(
480 loc, newResultComponentTy, newOperand);
481 Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newTy, 0);
482 Value newRes = insertLastDimSlice(rewriter, loc, source: extended, dest: zeroCst, lastOffset: 0);
483 rewriter.replaceOp(op, newRes);
484 return success();
485 }
486};
487
488//===----------------------------------------------------------------------===//
489// ConvertMaxMin
490//===----------------------------------------------------------------------===//
491
492template <typename SourceOp, arith::CmpIPredicate CmpPred>
493struct ConvertMaxMin final : OpConversionPattern<SourceOp> {
494 using OpConversionPattern<SourceOp>::OpConversionPattern;
495
496 LogicalResult
497 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
498 ConversionPatternRewriter &rewriter) const override {
499 Location loc = op->getLoc();
500
501 Type oldTy = op.getType();
502 auto newTy = dyn_cast_or_null<VectorType>(
503 this->getTypeConverter()->convertType(oldTy));
504 if (!newTy)
505 return rewriter.notifyMatchFailure(
506 loc, llvm::formatv("unsupported type: {0}", op.getType()));
507
508 // Rewrite Max*I/Min*I as compare and select over original operands. Let
509 // the CmpI and Select emulation patterns handle the final legalization.
510 Value cmp =
511 rewriter.create<arith::CmpIOp>(loc, CmpPred, op.getLhs(), op.getRhs());
512 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, op.getLhs(),
513 op.getRhs());
514 return success();
515 }
516};
517
//===----------------------------------------------------------------------===//
// Convert IndexCast ops
//===----------------------------------------------------------------------===//
520
521/// Returns true iff the type is `index` or `vector<...index>`.
522static bool isIndexOrIndexVector(Type type) {
523 if (isa<IndexType>(Val: type))
524 return true;
525
526 if (auto vectorTy = dyn_cast<VectorType>(type))
527 if (isa<IndexType>(vectorTy.getElementType()))
528 return true;
529
530 return false;
531}
532
533template <typename CastOp>
534struct ConvertIndexCastIntToIndex final : OpConversionPattern<CastOp> {
535 using OpConversionPattern<CastOp>::OpConversionPattern;
536
537 LogicalResult
538 matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
539 ConversionPatternRewriter &rewriter) const override {
540 Type resultType = op.getType();
541 if (!isIndexOrIndexVector(type: resultType))
542 return failure();
543
544 Location loc = op.getLoc();
545 Type inType = op.getIn().getType();
546 auto newInTy =
547 this->getTypeConverter()->template convertType<VectorType>(inType);
548 if (!newInTy)
549 return rewriter.notifyMatchFailure(
550 arg&: loc, msg: llvm::formatv(Fmt: "unsupported type: {0}", Vals&: inType));
551
552 // Discard the high half of the input truncating the original value.
553 Value extracted = extractLastDimSlice(rewriter, loc, adaptor.getIn(), 0);
554 extracted = dropTrailingX1Dim(rewriter, loc, input: extracted);
555 rewriter.replaceOpWithNewOp<CastOp>(op, resultType, extracted);
556 return success();
557 }
558};
559
560template <typename CastOp, typename ExtensionOp>
561struct ConvertIndexCastIndexToInt final : OpConversionPattern<CastOp> {
562 using OpConversionPattern<CastOp>::OpConversionPattern;
563
564 LogicalResult
565 matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
566 ConversionPatternRewriter &rewriter) const override {
567 Type inType = op.getIn().getType();
568 if (!isIndexOrIndexVector(type: inType))
569 return failure();
570
571 Location loc = op.getLoc();
572 auto *typeConverter =
573 this->template getTypeConverter<arith::WideIntEmulationConverter>();
574
575 Type resultType = op.getType();
576 auto newTy = typeConverter->template convertType<VectorType>(resultType);
577 if (!newTy)
578 return rewriter.notifyMatchFailure(
579 arg&: loc, msg: llvm::formatv(Fmt: "unsupported type: {0}", Vals&: resultType));
580
581 // Emit an index cast over the matching narrow type.
582 Type narrowTy =
583 rewriter.getIntegerType(typeConverter->getMaxTargetIntBitWidth());
584 if (auto vecTy = dyn_cast<VectorType>(resultType))
585 narrowTy = VectorType::get(vecTy.getShape(), narrowTy);
586
587 // Sign or zero-extend the result. Let the matching conversion pattern
588 // legalize the extension op.
589 Value underlyingVal =
590 rewriter.create<CastOp>(loc, narrowTy, adaptor.getIn());
591 rewriter.replaceOpWithNewOp<ExtensionOp>(op, resultType, underlyingVal);
592 return success();
593 }
594};
595
596//===----------------------------------------------------------------------===//
597// ConvertSelect
598//===----------------------------------------------------------------------===//
599
600struct ConvertSelect final : OpConversionPattern<arith::SelectOp> {
601 using OpConversionPattern::OpConversionPattern;
602
603 LogicalResult
604 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
605 ConversionPatternRewriter &rewriter) const override {
606 Location loc = op->getLoc();
607 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
608 if (!newTy)
609 return rewriter.notifyMatchFailure(
610 loc, llvm::formatv("unsupported type: {0}", op.getType()));
611
612 auto [trueElem0, trueElem1] =
613 extractLastDimHalves(rewriter, loc, adaptor.getTrueValue());
614 auto [falseElem0, falseElem1] =
615 extractLastDimHalves(rewriter, loc, adaptor.getFalseValue());
616 Value cond = appendX1Dim(rewriter, loc, adaptor.getCondition());
617
618 Value resElem0 =
619 rewriter.create<arith::SelectOp>(loc, cond, trueElem0, falseElem0);
620 Value resElem1 =
621 rewriter.create<arith::SelectOp>(loc, cond, trueElem1, falseElem1);
622 Value resultVec =
623 constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
624 rewriter.replaceOp(op, resultVec);
625 return success();
626 }
627};
628
629//===----------------------------------------------------------------------===//
630// ConvertShLI
631//===----------------------------------------------------------------------===//
632
633struct ConvertShLI final : OpConversionPattern<arith::ShLIOp> {
634 using OpConversionPattern::OpConversionPattern;
635
636 LogicalResult
637 matchAndRewrite(arith::ShLIOp op, OpAdaptor adaptor,
638 ConversionPatternRewriter &rewriter) const override {
639 Location loc = op->getLoc();
640
641 Type oldTy = op.getType();
642 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
643 if (!newTy)
644 return rewriter.notifyMatchFailure(
645 loc, llvm::formatv("unsupported type: {0}", op.getType()));
646
647 Type newOperandTy = reduceInnermostDim(newTy);
648 // `oldBitWidth` == `2 * newBitWidth`
649 unsigned newBitWidth = newTy.getElementTypeBitWidth();
650
651 auto [lhsElem0, lhsElem1] =
652 extractLastDimHalves(rewriter, loc, adaptor.getLhs());
653 Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
654
655 // Assume that the shift amount is < 2 * newBitWidth. Calculate the low and
656 // high halves of the results separately:
657 // 1. low := LHS.low shli RHS
658 //
659 // 2. high := a or b or c, where:
660 // a) Bits from LHS.high, shifted by the RHS.
661 // b) Bits from LHS.low, shifted right. These come into play when
662 // RHS < newBitWidth, e.g.:
663 // [0000][llll] shli 3 --> [0lll][l000]
664 // ^
665 // |
666 // [llll] shrui (4 - 3)
667 // c) Bits from LHS.low, shifted left. These matter when
668 // RHS > newBitWidth, e.g.:
669 // [0000][llll] shli 7 --> [l000][0000]
670 // ^
671 // |
672 // [llll] shli (7 - 4)
673 //
674 // Because shifts by values >= newBitWidth are undefined, we ignore the high
675 // half of RHS, and introduce 'bounds checks' to account for
676 // RHS.low > newBitWidth.
677 //
678 // TODO: Explore possible optimizations.
679 Value zeroCst = createScalarOrSplatConstant(builder&: rewriter, loc, type: newOperandTy, value: 0);
680 Value elemBitWidth =
681 createScalarOrSplatConstant(builder&: rewriter, loc, type: newOperandTy, value: newBitWidth);
682
683 Value illegalElemShift = rewriter.create<arith::CmpIOp>(
684 loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
685
686 Value shiftedElem0 =
687 rewriter.create<arith::ShLIOp>(loc, lhsElem0, rhsElem0);
688 Value resElem0 = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
689 zeroCst, shiftedElem0);
690
691 Value cappedShiftAmount = rewriter.create<arith::SelectOp>(
692 loc, illegalElemShift, elemBitWidth, rhsElem0);
693 Value rightShiftAmount =
694 rewriter.create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount);
695 Value shiftedRight =
696 rewriter.create<arith::ShRUIOp>(loc, lhsElem0, rightShiftAmount);
697 Value overshotShiftAmount =
698 rewriter.create<arith::SubIOp>(loc, rhsElem0, elemBitWidth);
699 Value shiftedLeft =
700 rewriter.create<arith::ShLIOp>(loc, lhsElem0, overshotShiftAmount);
701
702 Value shiftedElem1 =
703 rewriter.create<arith::ShLIOp>(loc, lhsElem1, rhsElem0);
704 Value resElem1High = rewriter.create<arith::SelectOp>(
705 loc, illegalElemShift, zeroCst, shiftedElem1);
706 Value resElem1Low = rewriter.create<arith::SelectOp>(
707 loc, illegalElemShift, shiftedLeft, shiftedRight);
708 Value resElem1 =
709 rewriter.create<arith::OrIOp>(loc, resElem1Low, resElem1High);
710
711 Value resultVec =
712 constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
713 rewriter.replaceOp(op, resultVec);
714 return success();
715 }
716};
717
718//===----------------------------------------------------------------------===//
719// ConvertShRUI
720//===----------------------------------------------------------------------===//
721
722struct ConvertShRUI final : OpConversionPattern<arith::ShRUIOp> {
723 using OpConversionPattern::OpConversionPattern;
724
725 LogicalResult
726 matchAndRewrite(arith::ShRUIOp op, OpAdaptor adaptor,
727 ConversionPatternRewriter &rewriter) const override {
728 Location loc = op->getLoc();
729
730 Type oldTy = op.getType();
731 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
732 if (!newTy)
733 return rewriter.notifyMatchFailure(
734 loc, llvm::formatv("unsupported type: {0}", op.getType()));
735
736 Type newOperandTy = reduceInnermostDim(newTy);
737 // `oldBitWidth` == `2 * newBitWidth`
738 unsigned newBitWidth = newTy.getElementTypeBitWidth();
739
740 auto [lhsElem0, lhsElem1] =
741 extractLastDimHalves(rewriter, loc, adaptor.getLhs());
742 Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
743
744 // Assume that the shift amount is < 2 * newBitWidth. Calculate the low and
745 // high halves of the results separately:
746 // 1. low := a or b or c, where:
747 // a) Bits from LHS.low, shifted by the RHS.
748 // b) Bits from LHS.high, shifted left. These matter when
749 // RHS < newBitWidth, e.g.:
750 // [hhhh][0000] shrui 3 --> [000h][hhh0]
751 // ^
752 // |
753 // [hhhh] shli (4 - 1)
754 // c) Bits from LHS.high, shifted right. These come into play when
755 // RHS > newBitWidth, e.g.:
756 // [hhhh][0000] shrui 7 --> [0000][000h]
757 // ^
758 // |
759 // [hhhh] shrui (7 - 4)
760 //
761 // 2. high := LHS.high shrui RHS
762 //
763 // Because shifts by values >= newBitWidth are undefined, we ignore the high
764 // half of RHS, and introduce 'bounds checks' to account for
765 // RHS.low > newBitWidth.
766 //
767 // TODO: Explore possible optimizations.
768 Value zeroCst = createScalarOrSplatConstant(builder&: rewriter, loc, type: newOperandTy, value: 0);
769 Value elemBitWidth =
770 createScalarOrSplatConstant(builder&: rewriter, loc, type: newOperandTy, value: newBitWidth);
771
772 Value illegalElemShift = rewriter.create<arith::CmpIOp>(
773 loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
774
775 Value shiftedElem0 =
776 rewriter.create<arith::ShRUIOp>(loc, lhsElem0, rhsElem0);
777 Value resElem0Low = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
778 zeroCst, shiftedElem0);
779 Value shiftedElem1 =
780 rewriter.create<arith::ShRUIOp>(loc, lhsElem1, rhsElem0);
781 Value resElem1 = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
782 zeroCst, shiftedElem1);
783
784 Value cappedShiftAmount = rewriter.create<arith::SelectOp>(
785 loc, illegalElemShift, elemBitWidth, rhsElem0);
786 Value leftShiftAmount =
787 rewriter.create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount);
788 Value shiftedLeft =
789 rewriter.create<arith::ShLIOp>(loc, lhsElem1, leftShiftAmount);
790 Value overshotShiftAmount =
791 rewriter.create<arith::SubIOp>(loc, rhsElem0, elemBitWidth);
792 Value shiftedRight =
793 rewriter.create<arith::ShRUIOp>(loc, lhsElem1, overshotShiftAmount);
794
795 Value resElem0High = rewriter.create<arith::SelectOp>(
796 loc, illegalElemShift, shiftedRight, shiftedLeft);
797 Value resElem0 =
798 rewriter.create<arith::OrIOp>(loc, resElem0Low, resElem0High);
799
800 Value resultVec =
801 constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
802 rewriter.replaceOp(op, resultVec);
803 return success();
804 }
805};
806
807//===----------------------------------------------------------------------===//
808// ConvertShRSI
809//===----------------------------------------------------------------------===//
810
811struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
812 using OpConversionPattern::OpConversionPattern;
813
814 LogicalResult
815 matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor,
816 ConversionPatternRewriter &rewriter) const override {
817 Location loc = op->getLoc();
818
819 Type oldTy = op.getType();
820 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
821 if (!newTy)
822 return rewriter.notifyMatchFailure(
823 loc, llvm::formatv("unsupported type: {0}", op.getType()));
824
825 Value lhsElem1 = extractLastDimSlice(rewriter, loc, adaptor.getLhs(), 1);
826 Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
827
828 Type narrowTy = rhsElem0.getType();
829 int64_t origBitwidth = newTy.getElementTypeBitWidth() * 2;
830
831 // Rewrite this as an bitwise or of `arith.shrui` and sign extension bits.
832 // Perform as many ops over the narrow integer type as possible and let the
833 // other emulation patterns convert the rest.
834 Value elemZero = createScalarOrSplatConstant(builder&: rewriter, loc, type: narrowTy, value: 0);
835 Value signBit = rewriter.create<arith::CmpIOp>(
836 loc, arith::CmpIPredicate::slt, lhsElem1, elemZero);
837 signBit = dropTrailingX1Dim(rewriter, loc, input: signBit);
838
839 // Create a bit pattern of either all ones or all zeros. Then shift it left
840 // to calculate the sign extension bits created by shifting the original
841 // sign bit right.
842 Value allSign = rewriter.create<arith::ExtSIOp>(loc, oldTy, signBit);
843 Value maxShift =
844 createScalarOrSplatConstant(builder&: rewriter, loc, type: narrowTy, value: origBitwidth);
845 Value numNonSignExtBits =
846 rewriter.create<arith::SubIOp>(loc, maxShift, rhsElem0);
847 numNonSignExtBits = dropTrailingX1Dim(rewriter, loc, input: numNonSignExtBits);
848 numNonSignExtBits =
849 rewriter.create<arith::ExtUIOp>(loc, oldTy, numNonSignExtBits);
850 Value signBits =
851 rewriter.create<arith::ShLIOp>(loc, allSign, numNonSignExtBits);
852
853 // Use original arguments to create the right shift.
854 Value shrui =
855 rewriter.create<arith::ShRUIOp>(loc, op.getLhs(), op.getRhs());
856 Value shrsi = rewriter.create<arith::OrIOp>(loc, shrui, signBits);
857
858 // Handle shifting by zero. This is necessary when the `signBits` shift is
859 // invalid.
860 Value isNoop = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
861 rhsElem0, elemZero);
862 isNoop = dropTrailingX1Dim(rewriter, loc, input: isNoop);
863 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNoop, op.getLhs(),
864 shrsi);
865
866 return success();
867 }
868};
869
870//===----------------------------------------------------------------------===//
871// ConvertSubI
872//===----------------------------------------------------------------------===//
873
874struct ConvertSubI final : OpConversionPattern<arith::SubIOp> {
875 using OpConversionPattern::OpConversionPattern;
876
877 LogicalResult
878 matchAndRewrite(arith::SubIOp op, OpAdaptor adaptor,
879 ConversionPatternRewriter &rewriter) const override {
880 Location loc = op->getLoc();
881 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
882 if (!newTy)
883 return rewriter.notifyMatchFailure(
884 loc, llvm::formatv("unsupported type: {}", op.getType()));
885
886 Type newElemTy = reduceInnermostDim(newTy);
887
888 auto [lhsElem0, lhsElem1] =
889 extractLastDimHalves(rewriter, loc, adaptor.getLhs());
890 auto [rhsElem0, rhsElem1] =
891 extractLastDimHalves(rewriter, loc, adaptor.getRhs());
892
893 // Emulates LHS - RHS by [LHS0 - RHS0, LHS1 - RHS1 - CARRY] where
894 // CARRY is 1 or 0.
895 Value low = rewriter.create<arith::SubIOp>(loc, lhsElem0, rhsElem0);
896 // We have a carry if lhsElem0 < rhsElem0.
897 Value carry0 = rewriter.create<arith::CmpIOp>(
898 loc, arith::CmpIPredicate::ult, lhsElem0, rhsElem0);
899 Value carryVal = rewriter.create<arith::ExtUIOp>(loc, newElemTy, carry0);
900
901 Value high0 = rewriter.create<arith::SubIOp>(loc, lhsElem1, carryVal);
902 Value high = rewriter.create<arith::SubIOp>(loc, high0, rhsElem1);
903
904 Value resultVec = constructResultVector(rewriter, loc, newTy, {low, high});
905 rewriter.replaceOp(op, resultVec);
906 return success();
907 }
908};
909
910//===----------------------------------------------------------------------===//
911// ConvertSIToFP
912//===----------------------------------------------------------------------===//
913
914struct ConvertSIToFP final : OpConversionPattern<arith::SIToFPOp> {
915 using OpConversionPattern::OpConversionPattern;
916
917 LogicalResult
918 matchAndRewrite(arith::SIToFPOp op, OpAdaptor adaptor,
919 ConversionPatternRewriter &rewriter) const override {
920 Location loc = op.getLoc();
921
922 Value in = op.getIn();
923 Type oldTy = in.getType();
924 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
925 if (!newTy)
926 return rewriter.notifyMatchFailure(
927 arg&: loc, msg: llvm::formatv(Fmt: "unsupported type: {0}", Vals&: oldTy));
928
929 Value zeroCst = createScalarOrSplatConstant(builder&: rewriter, loc, type: oldTy, value: 0);
930
931 // To avoid operating on very large unsigned numbers, perform the
932 // conversion on the absolute value. Then, decide whether to negate the
933 // result or not based on that sign bit. We implement negation by
934 // subtracting from zero. Note that this relies on the the other conversion
935 // patterns to legalize created ops and narrow the bit widths.
936 Value isNeg = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
937 in, zeroCst);
938 Value neg = rewriter.create<arith::SubIOp>(loc, zeroCst, in);
939 Value abs = rewriter.create<arith::SelectOp>(loc, isNeg, neg, in);
940
941 Value absResult = rewriter.create<arith::UIToFPOp>(loc, op.getType(), abs);
942 Value negResult = rewriter.create<arith::NegFOp>(loc, absResult);
943 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNeg, negResult,
944 absResult);
945 return success();
946 }
947};
948
949//===----------------------------------------------------------------------===//
950// ConvertUIToFP
951//===----------------------------------------------------------------------===//
952
953struct ConvertUIToFP final : OpConversionPattern<arith::UIToFPOp> {
954 using OpConversionPattern::OpConversionPattern;
955
956 LogicalResult
957 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
958 ConversionPatternRewriter &rewriter) const override {
959 Location loc = op.getLoc();
960
961 Type oldTy = op.getIn().getType();
962 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
963 if (!newTy)
964 return rewriter.notifyMatchFailure(
965 arg&: loc, msg: llvm::formatv(Fmt: "unsupported type: {0}", Vals&: oldTy));
966 unsigned newBitWidth = newTy.getElementTypeBitWidth();
967
968 auto [low, hi] = extractLastDimHalves(rewriter, loc, adaptor.getIn());
969 Value lowInt = dropTrailingX1Dim(rewriter, loc, low);
970 Value hiInt = dropTrailingX1Dim(rewriter, loc, hi);
971 Value zeroCst =
972 createScalarOrSplatConstant(builder&: rewriter, loc, type: hiInt.getType(), value: 0);
973
974 // The final result has the following form:
975 // if (hi == 0) return uitofp(low)
976 // else return uitofp(low) + uitofp(hi) * 2^BW
977 //
978 // where `BW` is the bitwidth of the narrowed integer type. We emit a
979 // select to make it easier to fold-away the `hi` part calculation when it
980 // is known to be zero.
981 //
982 // Note 1: The emulation is precise only for input values that have exact
983 // integer representation in the result floating point type, and may lead
984 // loss of precision otherwise.
985 //
986 // Note 2: We do not strictly need the `hi == 0`, case, but it makes
987 // constant folding easier.
988 Value hiEqZero = rewriter.create<arith::CmpIOp>(
989 loc, arith::CmpIPredicate::eq, hiInt, zeroCst);
990
991 Type resultTy = op.getType();
992 Type resultElemTy = getElementTypeOrSelf(type: resultTy);
993 Value lowFp = rewriter.create<arith::UIToFPOp>(loc, resultTy, lowInt);
994 Value hiFp = rewriter.create<arith::UIToFPOp>(loc, resultTy, hiInt);
995
996 int64_t pow2Int = int64_t(1) << newBitWidth;
997 TypedAttr pow2Attr =
998 rewriter.getFloatAttr(resultElemTy, static_cast<double>(pow2Int));
999 if (auto vecTy = dyn_cast<VectorType>(resultTy))
1000 pow2Attr = SplatElementsAttr::get(vecTy, pow2Attr);
1001
1002 Value pow2Val = rewriter.create<arith::ConstantOp>(loc, resultTy, pow2Attr);
1003
1004 Value hiVal = rewriter.create<arith::MulFOp>(loc, hiFp, pow2Val);
1005 Value result = rewriter.create<arith::AddFOp>(loc, lowFp, hiVal);
1006
1007 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, hiEqZero, lowFp, result);
1008 return success();
1009 }
1010};
1011
1012//===----------------------------------------------------------------------===//
1013// ConvertFPToSI
1014//===----------------------------------------------------------------------===//
1015
1016struct ConvertFPToSI final : OpConversionPattern<arith::FPToSIOp> {
1017 using OpConversionPattern::OpConversionPattern;
1018
1019 LogicalResult
1020 matchAndRewrite(arith::FPToSIOp op, OpAdaptor adaptor,
1021 ConversionPatternRewriter &rewriter) const override {
1022 Location loc = op.getLoc();
1023 // Get the input float type.
1024 Value inFp = adaptor.getIn();
1025 Type fpTy = inFp.getType();
1026
1027 Type intTy = op.getType();
1028
1029 auto newTy = getTypeConverter()->convertType<VectorType>(intTy);
1030 if (!newTy)
1031 return rewriter.notifyMatchFailure(
1032 arg&: loc, msg: llvm::formatv(Fmt: "unsupported type: {}", Vals&: intTy));
1033
1034 // Work on the absolute value and then convert the result to signed integer.
1035 // Defer absolute value to fptoui. If minSInt < fp < maxSInt, i.e. if the fp
1036 // is representable in signed i2N, emits the correct result. Else, the
1037 // result is UB.
1038
1039 TypedAttr zeroAttr = rewriter.getZeroAttr(fpTy);
1040 Value zeroCst = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
1041 Value zeroCstInt = createScalarOrSplatConstant(builder&: rewriter, loc, type: intTy, value: 0);
1042
1043 // Get the absolute value. One could have used math.absf here, but that
1044 // introduces an extra dependency.
1045 Value isNeg = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
1046 inFp, zeroCst);
1047 Value negInFp = rewriter.create<arith::NegFOp>(loc, inFp);
1048
1049 Value absVal = rewriter.create<arith::SelectOp>(loc, isNeg, negInFp, inFp);
1050
1051 // Defer the absolute value to fptoui.
1052 Value res = rewriter.create<arith::FPToUIOp>(loc, intTy, absVal);
1053
1054 // Negate the value if < 0 .
1055 Value neg = rewriter.create<arith::SubIOp>(loc, zeroCstInt, res);
1056
1057 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNeg, neg, res);
1058 return success();
1059 }
1060};
1061
1062//===----------------------------------------------------------------------===//
1063// ConvertFPToUI
1064//===----------------------------------------------------------------------===//
1065
1066struct ConvertFPToUI final : OpConversionPattern<arith::FPToUIOp> {
1067 using OpConversionPattern::OpConversionPattern;
1068
1069 LogicalResult
1070 matchAndRewrite(arith::FPToUIOp op, OpAdaptor adaptor,
1071 ConversionPatternRewriter &rewriter) const override {
1072 Location loc = op.getLoc();
1073 // Get the input float type.
1074 Value inFp = adaptor.getIn();
1075 Type fpTy = inFp.getType();
1076
1077 Type intTy = op.getType();
1078 auto newTy = getTypeConverter()->convertType<VectorType>(intTy);
1079 if (!newTy)
1080 return rewriter.notifyMatchFailure(
1081 arg&: loc, msg: llvm::formatv(Fmt: "unsupported type: {}", Vals&: intTy));
1082 unsigned newBitWidth = newTy.getElementTypeBitWidth();
1083
1084 Type newHalfType = IntegerType::get(inFp.getContext(), newBitWidth);
1085 if (auto vecType = dyn_cast<VectorType>(fpTy))
1086 newHalfType = VectorType::get(vecType.getShape(), newHalfType);
1087
1088 // The resulting integer has the upper part and the lower part. This would
1089 // be interpreted as 2^N * high + low, where N is the bitwidth. Therefore,
1090 // to calculate the higher part, we emit resHigh = fptoui(fp/2^N). For the
1091 // lower part, we emit fptoui(fp - resHigh * 2^N). The special cases of
1092 // overflows including +-inf, NaNs and negative numbers are UB.
1093
1094 const llvm::fltSemantics &fSemantics =
1095 cast<FloatType>(getElementTypeOrSelf(type: fpTy)).getFloatSemantics();
1096
1097 auto powBitwidth = llvm::APFloat(fSemantics);
1098 // If the integer does not fit the floating point number, we set the
1099 // powBitwidth to inf. This ensures that the upper part is set
1100 // correctly to 0. The opStatus inexact here only occurs when we have an
1101 // overflow, since the number is always a power of two.
1102 if (powBitwidth.convertFromAPInt(APInt(newBitWidth * 2, 1).shl(shiftAmt: newBitWidth),
1103 false, llvm::RoundingMode::TowardZero) ==
1104 llvm::detail::opStatus::opInexact)
1105 powBitwidth = llvm::APFloat::getInf(Sem: fSemantics);
1106
1107 TypedAttr powBitwidthAttr =
1108 FloatAttr::get(getElementTypeOrSelf(fpTy), powBitwidth);
1109 if (auto vecType = dyn_cast<VectorType>(fpTy))
1110 powBitwidthAttr = SplatElementsAttr::get(vecType, powBitwidthAttr);
1111 Value powBitwidthFloatCst =
1112 rewriter.create<arith::ConstantOp>(loc, powBitwidthAttr);
1113
1114 Value fpDivPowBitwidth =
1115 rewriter.create<arith::DivFOp>(loc, inFp, powBitwidthFloatCst);
1116 Value resHigh =
1117 rewriter.create<arith::FPToUIOp>(loc, newHalfType, fpDivPowBitwidth);
1118 // Calculate fp - resHigh * 2^N by getting the remainder of the division
1119 Value remainder =
1120 rewriter.create<arith::RemFOp>(loc, inFp, powBitwidthFloatCst);
1121 Value resLow =
1122 rewriter.create<arith::FPToUIOp>(loc, newHalfType, remainder);
1123
1124 Value high = appendX1Dim(rewriter, loc, input: resHigh);
1125 Value low = appendX1Dim(rewriter, loc, input: resLow);
1126
1127 Value resultVec = constructResultVector(rewriter, loc, newTy, {low, high});
1128
1129 rewriter.replaceOp(op, resultVec);
1130 return success();
1131 }
1132};
1133
1134//===----------------------------------------------------------------------===//
1135// ConvertTruncI
1136//===----------------------------------------------------------------------===//
1137
1138struct ConvertTruncI final : OpConversionPattern<arith::TruncIOp> {
1139 using OpConversionPattern::OpConversionPattern;
1140
1141 LogicalResult
1142 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
1143 ConversionPatternRewriter &rewriter) const override {
1144 Location loc = op.getLoc();
1145 // Check if the result type is legal for this target. Currently, we do not
1146 // support truncation to types wider than supported by the target.
1147 if (!getTypeConverter()->isLegal(op.getType()))
1148 return rewriter.notifyMatchFailure(
1149 loc, llvm::formatv("unsupported truncation result type: {0}",
1150 op.getType()));
1151
1152 // Discard the high half of the input. Truncate the low half, if
1153 // necessary.
1154 Value extracted = extractLastDimSlice(rewriter, loc, adaptor.getIn(), 0);
1155 extracted = dropTrailingX1Dim(rewriter, loc, input: extracted);
1156 Value truncated =
1157 rewriter.createOrFold<arith::TruncIOp>(loc, op.getType(), extracted);
1158 rewriter.replaceOp(op, truncated);
1159 return success();
1160 }
1161};
1162
1163//===----------------------------------------------------------------------===//
1164// ConvertVectorPrint
1165//===----------------------------------------------------------------------===//
1166
1167struct ConvertVectorPrint final : OpConversionPattern<vector::PrintOp> {
1168 using OpConversionPattern::OpConversionPattern;
1169
1170 LogicalResult
1171 matchAndRewrite(vector::PrintOp op, OpAdaptor adaptor,
1172 ConversionPatternRewriter &rewriter) const override {
1173 rewriter.replaceOpWithNewOp<vector::PrintOp>(op, adaptor.getSource());
1174 return success();
1175 }
1176};
1177
1178//===----------------------------------------------------------------------===//
1179// Pass Definition
1180//===----------------------------------------------------------------------===//
1181
1182struct EmulateWideIntPass final
1183 : arith::impl::ArithEmulateWideIntBase<EmulateWideIntPass> {
1184 using ArithEmulateWideIntBase::ArithEmulateWideIntBase;
1185
1186 void runOnOperation() override {
1187 if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
1188 signalPassFailure();
1189 return;
1190 }
1191
1192 Operation *op = getOperation();
1193 MLIRContext *ctx = op->getContext();
1194
1195 arith::WideIntEmulationConverter typeConverter(widestIntSupported);
1196 ConversionTarget target(*ctx);
1197 target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
1198 return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
1199 });
1200 auto opLegalCallback = [&typeConverter](Operation *op) {
1201 return typeConverter.isLegal(op);
1202 };
1203 target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
1204 target.addDynamicallyLegalOp<vector::PrintOp>(opLegalCallback);
1205 target.addDynamicallyLegalDialect<arith::ArithDialect>(opLegalCallback);
1206 target.addLegalDialect<vector::VectorDialect>();
1207
1208 RewritePatternSet patterns(ctx);
1209 arith::populateArithWideIntEmulationPatterns(typeConverter: typeConverter, patterns);
1210
1211 // Populate `func.*` conversion patterns.
1212 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
1213 patterns, typeConverter);
1214 populateCallOpTypeConversionPattern(patterns, typeConverter);
1215 populateReturnOpTypeConversionPattern(patterns, typeConverter);
1216
1217 if (failed(applyPartialConversion(op, target, std::move(patterns))))
1218 signalPassFailure();
1219 }
1220};
1221} // end anonymous namespace
1222
1223//===----------------------------------------------------------------------===//
1224// Public Interface Definition
1225//===----------------------------------------------------------------------===//
1226
1227arith::WideIntEmulationConverter::WideIntEmulationConverter(
1228 unsigned widestIntSupportedByTarget)
1229 : maxIntWidth(widestIntSupportedByTarget) {
1230 assert(llvm::isPowerOf2_32(widestIntSupportedByTarget) &&
1231 "Only power-of-two integers with are supported");
1232 assert(widestIntSupportedByTarget >= 2 && "Integer type too narrow");
1233
1234 // Allow unknown types.
1235 addConversion(callback: [](Type ty) -> std::optional<Type> { return ty; });
1236
1237 // Scalar case.
1238 addConversion(callback: [this](IntegerType ty) -> std::optional<Type> {
1239 unsigned width = ty.getWidth();
1240 if (width <= maxIntWidth)
1241 return ty;
1242
1243 // i2N --> vector<2xiN>
1244 if (width == 2 * maxIntWidth)
1245 return VectorType::get(2, IntegerType::get(ty.getContext(), maxIntWidth));
1246
1247 return nullptr;
1248 });
1249
1250 // Vector case.
1251 addConversion(callback: [this](VectorType ty) -> std::optional<Type> {
1252 auto intTy = dyn_cast<IntegerType>(ty.getElementType());
1253 if (!intTy)
1254 return ty;
1255
1256 unsigned width = intTy.getWidth();
1257 if (width <= maxIntWidth)
1258 return ty;
1259
1260 // vector<...xi2N> --> vector<...x2xiN>
1261 if (width == 2 * maxIntWidth) {
1262 auto newShape = to_vector(ty.getShape());
1263 newShape.push_back(2);
1264 return VectorType::get(newShape,
1265 IntegerType::get(ty.getContext(), maxIntWidth));
1266 }
1267
1268 return nullptr;
1269 });
1270
1271 // Function case.
1272 addConversion(callback: [this](FunctionType ty) -> std::optional<Type> {
1273 // Convert inputs and results, e.g.:
1274 // (i2N, i2N) -> i2N --> (vector<2xiN>, vector<2xiN>) -> vector<2xiN>
1275 SmallVector<Type> inputs;
1276 if (failed(convertTypes(types: ty.getInputs(), results&: inputs)))
1277 return nullptr;
1278
1279 SmallVector<Type> results;
1280 if (failed(convertTypes(types: ty.getResults(), results)))
1281 return nullptr;
1282
1283 return FunctionType::get(ty.getContext(), inputs, results);
1284 });
1285}
1286
1287void arith::populateArithWideIntEmulationPatterns(
1288 const WideIntEmulationConverter &typeConverter,
1289 RewritePatternSet &patterns) {
1290 // Populate `arith.*` conversion patterns.
1291 patterns.add<
1292 // Misc ops.
1293 ConvertConstant, ConvertCmpI, ConvertSelect, ConvertVectorPrint,
1294 // Binary ops.
1295 ConvertAddI, ConvertMulI, ConvertShLI, ConvertShRSI, ConvertShRUI,
1296 ConvertMaxMin<arith::MaxUIOp, arith::CmpIPredicate::ugt>,
1297 ConvertMaxMin<arith::MaxSIOp, arith::CmpIPredicate::sgt>,
1298 ConvertMaxMin<arith::MinUIOp, arith::CmpIPredicate::ult>,
1299 ConvertMaxMin<arith::MinSIOp, arith::CmpIPredicate::slt>, ConvertSubI,
1300 // Bitwise binary ops.
1301 ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
1302 ConvertBitwiseBinary<arith::XOrIOp>,
1303 // Extension and truncation ops.
1304 ConvertExtSI, ConvertExtUI, ConvertTruncI,
1305 // Cast ops.
1306 ConvertIndexCastIntToIndex<arith::IndexCastOp>,
1307 ConvertIndexCastIntToIndex<arith::IndexCastUIOp>,
1308 ConvertIndexCastIndexToInt<arith::IndexCastOp, arith::ExtSIOp>,
1309 ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>,
1310 ConvertSIToFP, ConvertUIToFP, ConvertFPToUI, ConvertFPToSI>(
1311 typeConverter, patterns.getContext());
1312}
1313

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp