//===- TestEmulateNarrowType.cpp - Test Narrow Type Emulation  -*- c++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
9
10#include "mlir/Dialect/Affine/IR/AffineOps.h"
11#include "mlir/Dialect/Arith/IR/Arith.h"
12#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
13#include "mlir/Dialect/Arith/Transforms/Passes.h"
14#include "mlir/Dialect/Func/IR/FuncOps.h"
15#include "mlir/Dialect/MemRef/IR/MemRef.h"
16#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
17#include "mlir/Dialect/Vector/IR/VectorOps.h"
18#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
19#include "mlir/Pass/Pass.h"
20#include "mlir/Transforms/DialectConversion.h"
21
22using namespace mlir;
23
24namespace {
25
26struct TestEmulateNarrowTypePass
27 : public PassWrapper<TestEmulateNarrowTypePass,
28 OperationPass<func::FuncOp>> {
29 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEmulateNarrowTypePass)
30
31 TestEmulateNarrowTypePass() = default;
32 TestEmulateNarrowTypePass(const TestEmulateNarrowTypePass &pass)
33 : PassWrapper(pass) {}
34
35 void getDependentDialects(DialectRegistry &registry) const override {
36 registry
37 .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect,
38 vector::VectorDialect, affine::AffineDialect>();
39 }
40 StringRef getArgument() const final { return "test-emulate-narrow-int"; }
41 StringRef getDescription() const final {
42 return "Function pass to test Narrow Integer Emulation";
43 }
44
45 void runOnOperation() override {
46 if (!llvm::isPowerOf2_32(loadStoreEmulateBitwidth) ||
47 loadStoreEmulateBitwidth < 8) {
48 signalPassFailure();
49 return;
50 }
51
52 Operation *op = getOperation();
53 MLIRContext *ctx = op->getContext();
54
55 arith::NarrowTypeEmulationConverter typeConverter(loadStoreEmulateBitwidth);
56
57 // Convert scalar type.
58 typeConverter.addConversion(callback: [this](IntegerType ty) -> std::optional<Type> {
59 unsigned width = ty.getWidth();
60 if (width >= arithComputeBitwidth)
61 return ty;
62
63 return IntegerType::get(ty.getContext(), arithComputeBitwidth);
64 });
65
66 // Convert vector type.
67 typeConverter.addConversion(callback: [this](VectorType ty) -> std::optional<Type> {
68 auto intTy = dyn_cast<IntegerType>(ty.getElementType());
69 if (!intTy)
70 return ty;
71
72 unsigned width = intTy.getWidth();
73 if (width >= arithComputeBitwidth)
74 return ty;
75
76 return VectorType::get(
77 to_vector(ty.getShape()),
78 IntegerType::get(ty.getContext(), arithComputeBitwidth));
79 });
80
81 memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);
82 ConversionTarget target(*ctx);
83 target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
84 return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
85 });
86 auto opLegalCallback = [&typeConverter](Operation *op) {
87 return typeConverter.isLegal(op);
88 };
89 target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
90 target.addDynamicallyLegalDialect<
91 arith::ArithDialect, vector::VectorDialect, memref::MemRefDialect,
92 affine::AffineDialect>(opLegalCallback);
93
94 RewritePatternSet patterns(ctx);
95
96 arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
97 memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
98 vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
99
100 if (failed(applyPartialConversion(op, target, std::move(patterns))))
101 signalPassFailure();
102 }
103
104 Option<unsigned> loadStoreEmulateBitwidth{
105 *this, "memref-load-bitwidth",
106 llvm::cl::desc("memref load/store emulation bit width"),
107 llvm::cl::init(8)};
108
109 Option<unsigned> arithComputeBitwidth{
110 *this, "arith-compute-bitwidth",
111 llvm::cl::desc("arith computation bit width"), llvm::cl::init(4)};
112};
113} // namespace
114
115namespace mlir::test {
116void registerTestEmulateNarrowTypePass() {
117 PassRegistration<TestEmulateNarrowTypePass>();
118}
119} // namespace mlir::test
120

// Source: mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp