1//===- TestWideIntEmulation.cpp - Test Wide Int Emulation ------*- c++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass for integration testing of wide integer
10// emulation patterns. Applies conversion patterns only to functions whose
11// names start with a specified prefix.
12//
13//===----------------------------------------------------------------------===//
14
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/Arith/Transforms/Passes.h"
17#include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h"
18#include "mlir/Dialect/Func/IR/FuncOps.h"
19#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20#include "mlir/Dialect/Vector/IR/VectorOps.h"
21#include "mlir/Pass/Pass.h"
22#include "mlir/Transforms/DialectConversion.h"
23
24using namespace mlir;
25
26namespace {
27struct TestEmulateWideIntPass
28 : public PassWrapper<TestEmulateWideIntPass, OperationPass<func::FuncOp>> {
29 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEmulateWideIntPass)
30
31 TestEmulateWideIntPass() = default;
32 TestEmulateWideIntPass(const TestEmulateWideIntPass &pass)
33 : PassWrapper(pass) {}
34
35 void getDependentDialects(DialectRegistry &registry) const override {
36 registry.insert<arith::ArithDialect, func::FuncDialect, LLVM::LLVMDialect,
37 vector::VectorDialect>();
38 }
39 StringRef getArgument() const final { return "test-arith-emulate-wide-int"; }
40 StringRef getDescription() const final {
41 return "Function pass to test Wide Integer Emulation";
42 }
43
44 void runOnOperation() override {
45 if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
46 signalPassFailure();
47 return;
48 }
49
50 func::FuncOp op = getOperation();
51 if (!op.getSymName().starts_with(testFunctionPrefix))
52 return;
53
54 MLIRContext *ctx = op.getContext();
55 arith::WideIntEmulationConverter typeConverter(widestIntSupported);
56
57 // Use `llvm.bitcast` as the bridge so that we can use preserve the
58 // function argument and return types of the processed function.
59 // TODO: Consider extending `arith.bitcast` to support scalar-to-1D-vector
60 // casts (and vice versa) and using it insted of `llvm.bitcast`.
61 auto addBitcast = [](OpBuilder &builder, Type type, ValueRange inputs,
62 Location loc) -> std::optional<Value> {
63 auto cast = builder.create<LLVM::BitcastOp>(loc, type, inputs);
64 return cast->getResult(0);
65 };
66 typeConverter.addSourceMaterialization(callback&: addBitcast);
67 typeConverter.addTargetMaterialization(callback&: addBitcast);
68
69 ConversionTarget target(*ctx);
70 target
71 .addDynamicallyLegalDialect<arith::ArithDialect, vector::VectorDialect>(
72 [&typeConverter](Operation *op) {
73 return typeConverter.isLegal(op);
74 });
75
76 RewritePatternSet patterns(ctx);
77 arith::populateArithWideIntEmulationPatterns(typeConverter, patterns);
78 if (failed(applyPartialConversion(op, target, std::move(patterns))))
79 signalPassFailure();
80 }
81
82 Option<std::string> testFunctionPrefix{
83 *this, "function-prefix",
84 llvm::cl::desc("Prefix of functions to run the emulation pass on"),
85 llvm::cl::init("emulate_")};
86 Option<unsigned> widestIntSupported{
87 *this, "widest-int-supported",
88 llvm::cl::desc("Maximum integer bit width supported by the target"),
89 llvm::cl::init(32)};
90};
91} // namespace
92
93namespace mlir::test {
94void registerTestArithEmulateWideIntPass() {
95 PassRegistration<TestEmulateWideIntPass>();
96}
97} // namespace mlir::test
98

source code of mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp