1 | //===- CharacterConversion.cpp -- convert between character encodings -----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "flang/Optimizer/Dialect/FIRDialect.h" |
10 | #include "flang/Optimizer/Dialect/FIROps.h" |
11 | #include "flang/Optimizer/Dialect/FIRType.h" |
12 | #include "flang/Optimizer/Dialect/Support/FIRContext.h" |
13 | #include "flang/Optimizer/Dialect/Support/KindMapping.h" |
14 | #include "flang/Optimizer/Transforms/Passes.h" |
15 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
16 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
17 | #include "mlir/IR/Diagnostics.h" |
18 | #include "mlir/Pass/Pass.h" |
19 | #include "mlir/Transforms/DialectConversion.h" |
20 | #include "llvm/Support/Debug.h" |
21 | |
22 | namespace fir { |
23 | #define GEN_PASS_DEF_CHARACTERCONVERSION |
24 | #include "flang/Optimizer/Transforms/Passes.h.inc" |
25 | } // namespace fir |
26 | |
27 | #define DEBUG_TYPE "flang-character-conversion" |
28 | |
29 | namespace { |
30 | |
/// Options controlling how character conversion is lowered.
// TODO: Future hook to select some set of runtime calls.
struct CharacterConversionOptions {
  /// Name selecting a runtime implementation. An empty string means no runtime
  /// was requested.
  std::string runtimeName{};
};
35 | |
36 | class CharacterConvertConversion |
37 | : public mlir::OpRewritePattern<fir::CharConvertOp> { |
38 | public: |
39 | using OpRewritePattern::OpRewritePattern; |
40 | |
41 | mlir::LogicalResult |
42 | matchAndRewrite(fir::CharConvertOp conv, |
43 | mlir::PatternRewriter &rewriter) const override { |
44 | auto kindMap = fir::getKindMapping(conv->getParentOfType<mlir::ModuleOp>()); |
45 | auto loc = conv.getLoc(); |
46 | |
47 | LLVM_DEBUG(llvm::dbgs() |
48 | << "running character conversion on " << conv << '\n'); |
49 | |
50 | // Establish a loop that executes count iterations. |
51 | auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0); |
52 | auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1); |
53 | auto idxTy = rewriter.getIndexType(); |
54 | auto castCnt = rewriter.create<fir::ConvertOp>(loc, idxTy, conv.getCount()); |
55 | auto countm1 = rewriter.create<mlir::arith::SubIOp>(loc, castCnt, one); |
56 | auto loop = rewriter.create<fir::DoLoopOp>(loc, zero, countm1, one); |
57 | auto insPt = rewriter.saveInsertionPoint(); |
58 | rewriter.setInsertionPointToStart(loop.getBody()); |
59 | |
60 | // For each code point in the `from` string, convert naively to the `to` |
61 | // string code point. Conversion is done blindly on size only, not value. |
62 | auto getCharBits = [&](mlir::Type t) { |
63 | auto chrTy = fir::unwrapSequenceType(fir::dyn_cast_ptrEleTy(t)) |
64 | .cast<fir::CharacterType>(); |
65 | return kindMap.getCharacterBitsize(chrTy.getFKind()); |
66 | }; |
67 | auto fromBits = getCharBits(conv.getFrom().getType()); |
68 | auto toBits = getCharBits(conv.getTo().getType()); |
69 | auto pointerType = [&](unsigned bits) { |
70 | return fir::ReferenceType::get(fir::SequenceType::get( |
71 | fir::SequenceType::ShapeRef{fir::SequenceType::getUnknownExtent()}, |
72 | rewriter.getIntegerType(bits))); |
73 | }; |
74 | auto fromPtrTy = pointerType(fromBits); |
75 | auto toTy = rewriter.getIntegerType(toBits); |
76 | auto toPtrTy = pointerType(toBits); |
77 | auto fromPtr = |
78 | rewriter.create<fir::ConvertOp>(loc, fromPtrTy, conv.getFrom()); |
79 | auto toPtr = rewriter.create<fir::ConvertOp>(loc, toPtrTy, conv.getTo()); |
80 | auto getEleTy = [&](unsigned bits) { |
81 | return fir::ReferenceType::get(rewriter.getIntegerType(bits)); |
82 | }; |
83 | auto fromi = rewriter.create<fir::CoordinateOp>( |
84 | loc, getEleTy(fromBits), fromPtr, |
85 | mlir::ValueRange{loop.getInductionVar()}); |
86 | auto toi = rewriter.create<fir::CoordinateOp>( |
87 | loc, getEleTy(toBits), toPtr, mlir::ValueRange{loop.getInductionVar()}); |
88 | auto load = rewriter.create<fir::LoadOp>(loc, fromi); |
89 | mlir::Value icast = |
90 | (fromBits >= toBits) |
91 | ? rewriter.create<fir::ConvertOp>(loc, toTy, load).getResult() |
92 | : rewriter.create<mlir::arith::ExtUIOp>(loc, toTy, load) |
93 | .getResult(); |
94 | rewriter.replaceOpWithNewOp<fir::StoreOp>(conv, icast, toi); |
95 | rewriter.restoreInsertionPoint(insPt); |
96 | return mlir::success(); |
97 | } |
98 | }; |
99 | |
100 | /// Rewrite the `fir.char_convert` op into a loop. This pass must be run only on |
101 | /// fir::CharConvertOp. |
102 | class CharacterConversion |
103 | : public fir::impl::CharacterConversionBase<CharacterConversion> { |
104 | public: |
105 | void runOnOperation() override { |
106 | CharacterConversionOptions clOpts{useRuntimeCalls.getValue()}; |
107 | if (clOpts.runtimeName.empty()) { |
108 | auto *context = &getContext(); |
109 | auto *func = getOperation(); |
110 | mlir::RewritePatternSet patterns(context); |
111 | patterns.insert<CharacterConvertConversion>(context); |
112 | mlir::ConversionTarget target(*context); |
113 | target.addLegalDialect<mlir::affine::AffineDialect, fir::FIROpsDialect, |
114 | mlir::arith::ArithDialect, |
115 | mlir::func::FuncDialect>(); |
116 | |
117 | // apply the patterns |
118 | target.addIllegalOp<fir::CharConvertOp>(); |
119 | if (mlir::failed(mlir::applyPartialConversion(func, target, |
120 | std::move(patterns)))) { |
121 | mlir::emitError(mlir::UnknownLoc::get(context), |
122 | "error in rewriting character convert op" ); |
123 | signalPassFailure(); |
124 | } |
125 | return; |
126 | } |
127 | |
128 | // TODO: some sort of runtime supported conversion? |
129 | signalPassFailure(); |
130 | } |
131 | }; |
132 | } // end anonymous namespace |
133 | |
134 | std::unique_ptr<mlir::Pass> fir::createCharacterConversionPass() { |
135 | return std::make_unique<CharacterConversion>(); |
136 | } |
137 | |