1 | //===- LowerHLFIRIntrinsics.cpp - Transformational intrinsics to FIR ------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "flang/Optimizer/Builder/FIRBuilder.h" |
10 | #include "flang/Optimizer/Builder/HLFIRTools.h" |
11 | #include "flang/Optimizer/Builder/IntrinsicCall.h" |
12 | #include "flang/Optimizer/Builder/Todo.h" |
13 | #include "flang/Optimizer/Dialect/FIRDialect.h" |
14 | #include "flang/Optimizer/Dialect/FIROps.h" |
15 | #include "flang/Optimizer/Dialect/FIRType.h" |
16 | #include "flang/Optimizer/Dialect/Support/FIRContext.h" |
17 | #include "flang/Optimizer/HLFIR/HLFIRDialect.h" |
18 | #include "flang/Optimizer/HLFIR/HLFIROps.h" |
19 | #include "flang/Optimizer/HLFIR/Passes.h" |
20 | #include "mlir/IR/BuiltinDialect.h" |
21 | #include "mlir/IR/MLIRContext.h" |
22 | #include "mlir/IR/PatternMatch.h" |
23 | #include "mlir/Pass/Pass.h" |
24 | #include "mlir/Pass/PassManager.h" |
25 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
26 | #include <optional> |
27 | |
28 | namespace hlfir { |
29 | #define GEN_PASS_DEF_LOWERHLFIRINTRINSICS |
30 | #include "flang/Optimizer/HLFIR/Passes.h.inc" |
31 | } // namespace hlfir |
32 | |
33 | namespace { |
34 | |
35 | /// Base class for passes converting transformational intrinsic operations into |
36 | /// runtime calls |
37 | template <class OP> |
38 | class HlfirIntrinsicConversion : public mlir::OpRewritePattern<OP> { |
39 | public: |
40 | explicit HlfirIntrinsicConversion(mlir::MLIRContext *ctx) |
41 | : mlir::OpRewritePattern<OP>{ctx} { |
42 | // required for cases where intrinsics are chained together e.g. |
43 | // matmul(matmul(a, b), c) |
44 | // because converting the inner operation then invalidates the |
45 | // outer operation: causing the pattern to apply recursively. |
46 | // |
47 | // This is safe because we always progress with each iteration. Circular |
48 | // applications of operations are not expressible in MLIR because we use |
49 | // an SSA form and one must become first. E.g. |
50 | // %a = hlfir.matmul %b %d |
51 | // %b = hlfir.matmul %a %d |
52 | // cannot be written. |
53 | // MSVC needs the this-> |
54 | this->setHasBoundedRewriteRecursion(true); |
55 | } |
56 | |
57 | protected: |
58 | struct IntrinsicArgument { |
59 | mlir::Value val; // allowed to be null if the argument is absent |
60 | mlir::Type desiredType; |
61 | }; |
62 | |
63 | /// Lower the arguments to the intrinsic: adding necessary boxing and |
64 | /// conversion to match the signature of the intrinsic in the runtime library. |
65 | llvm::SmallVector<fir::ExtendedValue, 3> |
66 | lowerArguments(mlir::Operation *op, |
67 | const llvm::ArrayRef<IntrinsicArgument> &args, |
68 | mlir::PatternRewriter &rewriter, |
69 | const fir::IntrinsicArgumentLoweringRules *argLowering) const { |
70 | mlir::Location loc = op->getLoc(); |
71 | fir::FirOpBuilder builder{rewriter, op}; |
72 | |
73 | llvm::SmallVector<fir::ExtendedValue, 3> ret; |
74 | llvm::SmallVector<std::function<void()>, 2> cleanupFns; |
75 | |
76 | for (size_t i = 0; i < args.size(); ++i) { |
77 | mlir::Value arg = args[i].val; |
78 | mlir::Type desiredType = args[i].desiredType; |
79 | if (!arg) { |
80 | ret.emplace_back(fir::getAbsentIntrinsicArgument()); |
81 | continue; |
82 | } |
83 | hlfir::Entity entity{arg}; |
84 | |
85 | fir::ArgLoweringRule argRules = |
86 | fir::lowerIntrinsicArgumentAs(*argLowering, i); |
87 | switch (argRules.lowerAs) { |
88 | case fir::LowerIntrinsicArgAs::Value: { |
89 | if (args[i].desiredType != arg.getType()) { |
90 | arg = builder.createConvert(loc, desiredType, arg); |
91 | entity = hlfir::Entity{arg}; |
92 | } |
93 | auto [exv, cleanup] = hlfir::convertToValue(loc, builder, entity); |
94 | if (cleanup) |
95 | cleanupFns.push_back(*cleanup); |
96 | ret.emplace_back(exv); |
97 | } break; |
98 | case fir::LowerIntrinsicArgAs::Addr: { |
99 | auto [exv, cleanup] = |
100 | hlfir::convertToAddress(loc, builder, entity, desiredType); |
101 | if (cleanup) |
102 | cleanupFns.push_back(*cleanup); |
103 | ret.emplace_back(exv); |
104 | } break; |
105 | case fir::LowerIntrinsicArgAs::Box: { |
106 | auto [box, cleanup] = |
107 | hlfir::convertToBox(loc, builder, entity, desiredType); |
108 | if (cleanup) |
109 | cleanupFns.push_back(*cleanup); |
110 | ret.emplace_back(box); |
111 | } break; |
112 | case fir::LowerIntrinsicArgAs::Inquired: { |
113 | if (args[i].desiredType != arg.getType()) { |
114 | arg = builder.createConvert(loc, desiredType, arg); |
115 | entity = hlfir::Entity{arg}; |
116 | } |
117 | // Place hlfir.expr in memory, and unbox fir.boxchar. Other entities |
118 | // are translated to fir::ExtendedValue without transofrmation (notably, |
119 | // pointers/allocatable are not dereferenced). |
120 | // TODO: once lowering to FIR retires, UBOUND and LBOUND can be |
121 | // simplified since the fir.box lowered here are now guarenteed to |
122 | // contain the local lower bounds thanks to the hlfir.declare (the extra |
123 | // rebox can be removed). |
124 | // When taking arguments as descriptors, the runtime expect absent |
125 | // OPTIONAL to be a nullptr to a descriptor, lowering has already |
126 | // prepared such descriptors as needed, hence set |
127 | // keepScalarOptionalBoxed to avoid building descriptors with a null |
128 | // address for them. |
129 | auto [exv, cleanup] = hlfir::translateToExtendedValue( |
130 | loc, builder, entity, /*contiguous=*/false, |
131 | /*keepScalarOptionalBoxed=*/true); |
132 | if (cleanup) |
133 | cleanupFns.push_back(*cleanup); |
134 | ret.emplace_back(exv); |
135 | } break; |
136 | } |
137 | } |
138 | |
139 | if (cleanupFns.size()) { |
140 | auto oldInsertionPoint = builder.saveInsertionPoint(); |
141 | builder.setInsertionPointAfter(op); |
142 | for (std::function<void()> cleanup : cleanupFns) |
143 | cleanup(); |
144 | builder.restoreInsertionPoint(oldInsertionPoint); |
145 | } |
146 | |
147 | return ret; |
148 | } |
149 | |
150 | void processReturnValue(mlir::Operation *op, |
151 | const fir::ExtendedValue &resultExv, bool mustBeFreed, |
152 | fir::FirOpBuilder &builder, |
153 | mlir::PatternRewriter &rewriter) const { |
154 | mlir::Location loc = op->getLoc(); |
155 | |
156 | mlir::Value firBase = fir::getBase(resultExv); |
157 | mlir::Type firBaseTy = firBase.getType(); |
158 | |
159 | std::optional<hlfir::EntityWithAttributes> resultEntity; |
160 | if (fir::isa_trivial(firBaseTy)) { |
161 | // Some intrinsics return i1 when the original operation |
162 | // produces fir.logical<>, so we may need to cast it. |
163 | firBase = builder.createConvert(loc, op->getResult(0).getType(), firBase); |
164 | resultEntity = hlfir::EntityWithAttributes{firBase}; |
165 | } else { |
166 | resultEntity = |
167 | hlfir::genDeclare(loc, builder, resultExv, ".tmp.intrinsic_result" , |
168 | fir::FortranVariableFlagsAttr{}); |
169 | } |
170 | |
171 | if (resultEntity->isVariable()) { |
172 | hlfir::AsExprOp asExpr = builder.create<hlfir::AsExprOp>( |
173 | loc, *resultEntity, builder.createBool(loc, mustBeFreed)); |
174 | resultEntity = hlfir::EntityWithAttributes{asExpr.getResult()}; |
175 | } |
176 | |
177 | mlir::Value base = resultEntity->getBase(); |
178 | if (!mlir::isa<hlfir::ExprType>(base.getType())) { |
179 | for (mlir::Operation *use : op->getResult(0).getUsers()) { |
180 | if (mlir::isa<hlfir::DestroyOp>(use)) |
181 | rewriter.eraseOp(use); |
182 | } |
183 | } |
184 | |
185 | rewriter.replaceOp(op, base); |
186 | } |
187 | }; |
188 | |
189 | // Given an integer or array of integer type, calculate the Kind parameter from |
190 | // the width for use in runtime intrinsic calls. |
191 | static unsigned getKindForType(mlir::Type ty) { |
192 | mlir::Type eltty = hlfir::getFortranElementType(ty); |
193 | unsigned width = mlir::cast<mlir::IntegerType>(eltty).getWidth(); |
194 | return width / 8; |
195 | } |
196 | |
197 | template <class OP> |
198 | class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> { |
199 | using HlfirIntrinsicConversion<OP>::HlfirIntrinsicConversion; |
200 | using IntrinsicArgument = |
201 | typename HlfirIntrinsicConversion<OP>::IntrinsicArgument; |
202 | using HlfirIntrinsicConversion<OP>::lowerArguments; |
203 | using HlfirIntrinsicConversion<OP>::processReturnValue; |
204 | |
205 | protected: |
206 | auto buildNumericalArgs(OP operation, mlir::Type i32, mlir::Type logicalType, |
207 | mlir::PatternRewriter &rewriter, |
208 | std::string opName) const { |
209 | llvm::SmallVector<IntrinsicArgument, 3> inArgs; |
210 | inArgs.push_back({operation.getArray(), operation.getArray().getType()}); |
211 | inArgs.push_back({operation.getDim(), i32}); |
212 | inArgs.push_back({operation.getMask(), logicalType}); |
213 | auto *argLowering = fir::getIntrinsicArgumentLowering(opName); |
214 | return lowerArguments(operation, inArgs, rewriter, argLowering); |
215 | }; |
216 | |
217 | auto buildMinMaxLocArgs(OP operation, mlir::Type i32, mlir::Type logicalType, |
218 | mlir::PatternRewriter &rewriter, std::string opName, |
219 | fir::FirOpBuilder builder) const { |
220 | llvm::SmallVector<IntrinsicArgument, 3> inArgs; |
221 | inArgs.push_back({operation.getArray(), operation.getArray().getType()}); |
222 | inArgs.push_back({operation.getDim(), i32}); |
223 | inArgs.push_back({operation.getMask(), logicalType}); |
224 | mlir::Value kind = builder.createIntegerConstant( |
225 | operation->getLoc(), i32, getKindForType(operation.getType())); |
226 | inArgs.push_back({kind, i32}); |
227 | inArgs.push_back({operation.getBack(), i32}); |
228 | auto *argLowering = fir::getIntrinsicArgumentLowering(opName); |
229 | return lowerArguments(operation, inArgs, rewriter, argLowering); |
230 | }; |
231 | |
232 | auto buildLogicalArgs(OP operation, mlir::Type i32, mlir::Type logicalType, |
233 | mlir::PatternRewriter &rewriter, |
234 | std::string opName) const { |
235 | llvm::SmallVector<IntrinsicArgument, 2> inArgs; |
236 | inArgs.push_back({operation.getMask(), logicalType}); |
237 | inArgs.push_back({operation.getDim(), i32}); |
238 | auto *argLowering = fir::getIntrinsicArgumentLowering(opName); |
239 | return lowerArguments(operation, inArgs, rewriter, argLowering); |
240 | }; |
241 | |
242 | public: |
243 | llvm::LogicalResult |
244 | matchAndRewrite(OP operation, |
245 | mlir::PatternRewriter &rewriter) const override { |
246 | std::string opName; |
247 | if constexpr (std::is_same_v<OP, hlfir::SumOp>) { |
248 | opName = "sum" ; |
249 | } else if constexpr (std::is_same_v<OP, hlfir::ProductOp>) { |
250 | opName = "product" ; |
251 | } else if constexpr (std::is_same_v<OP, hlfir::MaxvalOp>) { |
252 | opName = "maxval" ; |
253 | } else if constexpr (std::is_same_v<OP, hlfir::MinvalOp>) { |
254 | opName = "minval" ; |
255 | } else if constexpr (std::is_same_v<OP, hlfir::MinlocOp>) { |
256 | opName = "minloc" ; |
257 | } else if constexpr (std::is_same_v<OP, hlfir::MaxlocOp>) { |
258 | opName = "maxloc" ; |
259 | } else if constexpr (std::is_same_v<OP, hlfir::AnyOp>) { |
260 | opName = "any" ; |
261 | } else if constexpr (std::is_same_v<OP, hlfir::AllOp>) { |
262 | opName = "all" ; |
263 | } else { |
264 | return mlir::failure(); |
265 | } |
266 | |
267 | fir::FirOpBuilder builder{rewriter, operation.getOperation()}; |
268 | const mlir::Location &loc = operation->getLoc(); |
269 | |
270 | mlir::Type i32 = builder.getI32Type(); |
271 | mlir::Type logicalType = fir::LogicalType::get( |
272 | builder.getContext(), builder.getKindMap().defaultLogicalKind()); |
273 | |
274 | llvm::SmallVector<fir::ExtendedValue, 0> args; |
275 | |
276 | if constexpr (std::is_same_v<OP, hlfir::SumOp> || |
277 | std::is_same_v<OP, hlfir::ProductOp> || |
278 | std::is_same_v<OP, hlfir::MaxvalOp> || |
279 | std::is_same_v<OP, hlfir::MinvalOp>) { |
280 | args = buildNumericalArgs(operation, i32, logicalType, rewriter, opName); |
281 | } else if constexpr (std::is_same_v<OP, hlfir::MinlocOp> || |
282 | std::is_same_v<OP, hlfir::MaxlocOp>) { |
283 | args = buildMinMaxLocArgs(operation, i32, logicalType, rewriter, opName, |
284 | builder); |
285 | } else { |
286 | args = buildLogicalArgs(operation, i32, logicalType, rewriter, opName); |
287 | } |
288 | |
289 | mlir::Type scalarResultType = |
290 | hlfir::getFortranElementType(operation.getType()); |
291 | |
292 | auto [resultExv, mustBeFreed] = |
293 | fir::genIntrinsicCall(builder, loc, opName, scalarResultType, args); |
294 | |
295 | processReturnValue(operation, resultExv, mustBeFreed, builder, rewriter); |
296 | return mlir::success(); |
297 | } |
298 | }; |
299 | |
300 | using SumOpConversion = HlfirReductionIntrinsicConversion<hlfir::SumOp>; |
301 | |
302 | using ProductOpConversion = HlfirReductionIntrinsicConversion<hlfir::ProductOp>; |
303 | |
304 | using MaxvalOpConversion = HlfirReductionIntrinsicConversion<hlfir::MaxvalOp>; |
305 | |
306 | using MinvalOpConversion = HlfirReductionIntrinsicConversion<hlfir::MinvalOp>; |
307 | |
308 | using MinlocOpConversion = HlfirReductionIntrinsicConversion<hlfir::MinlocOp>; |
309 | |
310 | using MaxlocOpConversion = HlfirReductionIntrinsicConversion<hlfir::MaxlocOp>; |
311 | |
312 | using AnyOpConversion = HlfirReductionIntrinsicConversion<hlfir::AnyOp>; |
313 | |
314 | using AllOpConversion = HlfirReductionIntrinsicConversion<hlfir::AllOp>; |
315 | |
316 | struct CountOpConversion : public HlfirIntrinsicConversion<hlfir::CountOp> { |
317 | using HlfirIntrinsicConversion<hlfir::CountOp>::HlfirIntrinsicConversion; |
318 | |
319 | llvm::LogicalResult |
320 | matchAndRewrite(hlfir::CountOp count, |
321 | mlir::PatternRewriter &rewriter) const override { |
322 | fir::FirOpBuilder builder{rewriter, count.getOperation()}; |
323 | const mlir::Location &loc = count->getLoc(); |
324 | |
325 | mlir::Type i32 = builder.getI32Type(); |
326 | mlir::Type logicalType = fir::LogicalType::get( |
327 | builder.getContext(), builder.getKindMap().defaultLogicalKind()); |
328 | |
329 | llvm::SmallVector<IntrinsicArgument, 3> inArgs; |
330 | inArgs.push_back({count.getMask(), logicalType}); |
331 | inArgs.push_back({count.getDim(), i32}); |
332 | mlir::Value kind = builder.createIntegerConstant( |
333 | count->getLoc(), i32, getKindForType(count.getType())); |
334 | inArgs.push_back({kind, i32}); |
335 | |
336 | auto *argLowering = fir::getIntrinsicArgumentLowering("count" ); |
337 | llvm::SmallVector<fir::ExtendedValue, 3> args = |
338 | lowerArguments(count, inArgs, rewriter, argLowering); |
339 | |
340 | mlir::Type scalarResultType = hlfir::getFortranElementType(count.getType()); |
341 | |
342 | auto [resultExv, mustBeFreed] = |
343 | fir::genIntrinsicCall(builder, loc, "count" , scalarResultType, args); |
344 | |
345 | processReturnValue(count, resultExv, mustBeFreed, builder, rewriter); |
346 | return mlir::success(); |
347 | } |
348 | }; |
349 | |
350 | struct MatmulOpConversion : public HlfirIntrinsicConversion<hlfir::MatmulOp> { |
351 | using HlfirIntrinsicConversion<hlfir::MatmulOp>::HlfirIntrinsicConversion; |
352 | |
353 | llvm::LogicalResult |
354 | matchAndRewrite(hlfir::MatmulOp matmul, |
355 | mlir::PatternRewriter &rewriter) const override { |
356 | fir::FirOpBuilder builder{rewriter, matmul.getOperation()}; |
357 | const mlir::Location &loc = matmul->getLoc(); |
358 | |
359 | mlir::Value lhs = matmul.getLhs(); |
360 | mlir::Value rhs = matmul.getRhs(); |
361 | llvm::SmallVector<IntrinsicArgument, 2> inArgs; |
362 | inArgs.push_back({lhs, lhs.getType()}); |
363 | inArgs.push_back({rhs, rhs.getType()}); |
364 | |
365 | auto *argLowering = fir::getIntrinsicArgumentLowering("matmul" ); |
366 | llvm::SmallVector<fir::ExtendedValue, 2> args = |
367 | lowerArguments(matmul, inArgs, rewriter, argLowering); |
368 | |
369 | mlir::Type scalarResultType = |
370 | hlfir::getFortranElementType(matmul.getType()); |
371 | |
372 | auto [resultExv, mustBeFreed] = |
373 | fir::genIntrinsicCall(builder, loc, "matmul" , scalarResultType, args); |
374 | |
375 | processReturnValue(matmul, resultExv, mustBeFreed, builder, rewriter); |
376 | return mlir::success(); |
377 | } |
378 | }; |
379 | |
380 | struct DotProductOpConversion |
381 | : public HlfirIntrinsicConversion<hlfir::DotProductOp> { |
382 | using HlfirIntrinsicConversion<hlfir::DotProductOp>::HlfirIntrinsicConversion; |
383 | |
384 | llvm::LogicalResult |
385 | matchAndRewrite(hlfir::DotProductOp dotProduct, |
386 | mlir::PatternRewriter &rewriter) const override { |
387 | fir::FirOpBuilder builder{rewriter, dotProduct.getOperation()}; |
388 | const mlir::Location &loc = dotProduct->getLoc(); |
389 | |
390 | mlir::Value lhs = dotProduct.getLhs(); |
391 | mlir::Value rhs = dotProduct.getRhs(); |
392 | llvm::SmallVector<IntrinsicArgument, 2> inArgs; |
393 | inArgs.push_back({lhs, lhs.getType()}); |
394 | inArgs.push_back({rhs, rhs.getType()}); |
395 | |
396 | auto *argLowering = fir::getIntrinsicArgumentLowering("dot_product" ); |
397 | llvm::SmallVector<fir::ExtendedValue, 2> args = |
398 | lowerArguments(dotProduct, inArgs, rewriter, argLowering); |
399 | |
400 | mlir::Type scalarResultType = |
401 | hlfir::getFortranElementType(dotProduct.getType()); |
402 | |
403 | auto [resultExv, mustBeFreed] = fir::genIntrinsicCall( |
404 | builder, loc, "dot_product" , scalarResultType, args); |
405 | |
406 | processReturnValue(dotProduct, resultExv, mustBeFreed, builder, rewriter); |
407 | return mlir::success(); |
408 | } |
409 | }; |
410 | |
411 | class TransposeOpConversion |
412 | : public HlfirIntrinsicConversion<hlfir::TransposeOp> { |
413 | using HlfirIntrinsicConversion<hlfir::TransposeOp>::HlfirIntrinsicConversion; |
414 | |
415 | llvm::LogicalResult |
416 | matchAndRewrite(hlfir::TransposeOp transpose, |
417 | mlir::PatternRewriter &rewriter) const override { |
418 | fir::FirOpBuilder builder{rewriter, transpose.getOperation()}; |
419 | const mlir::Location &loc = transpose->getLoc(); |
420 | |
421 | mlir::Value arg = transpose.getArray(); |
422 | llvm::SmallVector<IntrinsicArgument, 1> inArgs; |
423 | inArgs.push_back({arg, arg.getType()}); |
424 | |
425 | auto *argLowering = fir::getIntrinsicArgumentLowering("transpose" ); |
426 | llvm::SmallVector<fir::ExtendedValue, 1> args = |
427 | lowerArguments(transpose, inArgs, rewriter, argLowering); |
428 | |
429 | mlir::Type scalarResultType = |
430 | hlfir::getFortranElementType(transpose.getType()); |
431 | |
432 | auto [resultExv, mustBeFreed] = fir::genIntrinsicCall( |
433 | builder, loc, "transpose" , scalarResultType, args); |
434 | |
435 | processReturnValue(transpose, resultExv, mustBeFreed, builder, rewriter); |
436 | return mlir::success(); |
437 | } |
438 | }; |
439 | |
440 | struct MatmulTransposeOpConversion |
441 | : public HlfirIntrinsicConversion<hlfir::MatmulTransposeOp> { |
442 | using HlfirIntrinsicConversion< |
443 | hlfir::MatmulTransposeOp>::HlfirIntrinsicConversion; |
444 | |
445 | llvm::LogicalResult |
446 | matchAndRewrite(hlfir::MatmulTransposeOp multranspose, |
447 | mlir::PatternRewriter &rewriter) const override { |
448 | fir::FirOpBuilder builder{rewriter, multranspose.getOperation()}; |
449 | const mlir::Location &loc = multranspose->getLoc(); |
450 | |
451 | mlir::Value lhs = multranspose.getLhs(); |
452 | mlir::Value rhs = multranspose.getRhs(); |
453 | llvm::SmallVector<IntrinsicArgument, 2> inArgs; |
454 | inArgs.push_back({lhs, lhs.getType()}); |
455 | inArgs.push_back({rhs, rhs.getType()}); |
456 | |
457 | auto *argLowering = fir::getIntrinsicArgumentLowering("matmul" ); |
458 | llvm::SmallVector<fir::ExtendedValue, 2> args = |
459 | lowerArguments(multranspose, inArgs, rewriter, argLowering); |
460 | |
461 | mlir::Type scalarResultType = |
462 | hlfir::getFortranElementType(multranspose.getType()); |
463 | |
464 | auto [resultExv, mustBeFreed] = fir::genIntrinsicCall( |
465 | builder, loc, "matmul_transpose" , scalarResultType, args); |
466 | |
467 | processReturnValue(multranspose, resultExv, mustBeFreed, builder, rewriter); |
468 | return mlir::success(); |
469 | } |
470 | }; |
471 | |
472 | class CShiftOpConversion : public HlfirIntrinsicConversion<hlfir::CShiftOp> { |
473 | using HlfirIntrinsicConversion<hlfir::CShiftOp>::HlfirIntrinsicConversion; |
474 | |
475 | llvm::LogicalResult |
476 | matchAndRewrite(hlfir::CShiftOp cshift, |
477 | mlir::PatternRewriter &rewriter) const override { |
478 | fir::FirOpBuilder builder{rewriter, cshift.getOperation()}; |
479 | const mlir::Location &loc = cshift->getLoc(); |
480 | |
481 | llvm::SmallVector<IntrinsicArgument, 3> inArgs; |
482 | mlir::Value array = cshift.getArray(); |
483 | inArgs.push_back({array, array.getType()}); |
484 | mlir::Value shift = cshift.getShift(); |
485 | inArgs.push_back({shift, shift.getType()}); |
486 | inArgs.push_back({cshift.getDim(), builder.getI32Type()}); |
487 | |
488 | auto *argLowering = fir::getIntrinsicArgumentLowering("cshift" ); |
489 | llvm::SmallVector<fir::ExtendedValue, 3> args = |
490 | lowerArguments(cshift, inArgs, rewriter, argLowering); |
491 | |
492 | mlir::Type scalarResultType = |
493 | hlfir::getFortranElementType(cshift.getType()); |
494 | |
495 | auto [resultExv, mustBeFreed] = |
496 | fir::genIntrinsicCall(builder, loc, "cshift" , scalarResultType, args); |
497 | |
498 | processReturnValue(cshift, resultExv, mustBeFreed, builder, rewriter); |
499 | return mlir::success(); |
500 | } |
501 | }; |
502 | |
503 | class ReshapeOpConversion : public HlfirIntrinsicConversion<hlfir::ReshapeOp> { |
504 | using HlfirIntrinsicConversion<hlfir::ReshapeOp>::HlfirIntrinsicConversion; |
505 | |
506 | llvm::LogicalResult |
507 | matchAndRewrite(hlfir::ReshapeOp reshape, |
508 | mlir::PatternRewriter &rewriter) const override { |
509 | fir::FirOpBuilder builder{rewriter, reshape.getOperation()}; |
510 | const mlir::Location &loc = reshape->getLoc(); |
511 | |
512 | llvm::SmallVector<IntrinsicArgument, 4> inArgs; |
513 | mlir::Value array = reshape.getArray(); |
514 | inArgs.push_back({array, array.getType()}); |
515 | mlir::Value shape = reshape.getShape(); |
516 | inArgs.push_back({shape, shape.getType()}); |
517 | mlir::Type noneType = builder.getNoneType(); |
518 | mlir::Value pad = reshape.getPad(); |
519 | inArgs.push_back({pad, pad ? pad.getType() : noneType}); |
520 | mlir::Value order = reshape.getOrder(); |
521 | inArgs.push_back({order, order ? order.getType() : noneType}); |
522 | |
523 | auto *argLowering = fir::getIntrinsicArgumentLowering("reshape" ); |
524 | llvm::SmallVector<fir::ExtendedValue, 4> args = |
525 | lowerArguments(reshape, inArgs, rewriter, argLowering); |
526 | |
527 | mlir::Type scalarResultType = |
528 | hlfir::getFortranElementType(reshape.getType()); |
529 | |
530 | auto [resultExv, mustBeFreed] = |
531 | fir::genIntrinsicCall(builder, loc, "reshape" , scalarResultType, args); |
532 | |
533 | processReturnValue(reshape, resultExv, mustBeFreed, builder, rewriter); |
534 | return mlir::success(); |
535 | } |
536 | }; |
537 | |
538 | class LowerHLFIRIntrinsics |
539 | : public hlfir::impl::LowerHLFIRIntrinsicsBase<LowerHLFIRIntrinsics> { |
540 | public: |
541 | void runOnOperation() override { |
542 | mlir::ModuleOp module = this->getOperation(); |
543 | mlir::MLIRContext *context = &getContext(); |
544 | mlir::RewritePatternSet patterns(context); |
545 | patterns.insert< |
546 | MatmulOpConversion, MatmulTransposeOpConversion, AllOpConversion, |
547 | AnyOpConversion, SumOpConversion, ProductOpConversion, |
548 | TransposeOpConversion, CountOpConversion, DotProductOpConversion, |
549 | MaxvalOpConversion, MinvalOpConversion, MinlocOpConversion, |
550 | MaxlocOpConversion, CShiftOpConversion, ReshapeOpConversion>(context); |
551 | |
552 | // While conceptually this pass is performing dialect conversion, we use |
553 | // pattern rewrites here instead of dialect conversion because this pass |
554 | // looses array bounds from some of the expressions e.g. |
555 | // !hlfir.expr<2xi32> -> !hlfir.expr<?xi32> |
556 | // MLIR thinks this is a different type so dialect conversion fails. |
557 | // Pattern rewriting only requires that the resulting IR is still valid |
558 | mlir::GreedyRewriteConfig config; |
559 | // Prevent the pattern driver from merging blocks |
560 | config.setRegionSimplificationLevel( |
561 | mlir::GreedySimplifyRegionLevel::Disabled); |
562 | |
563 | if (mlir::failed( |
564 | mlir::applyPatternsGreedily(module, std::move(patterns), config))) { |
565 | mlir::emitError(mlir::UnknownLoc::get(context), |
566 | "failure in HLFIR intrinsic lowering" ); |
567 | signalPassFailure(); |
568 | } |
569 | } |
570 | }; |
571 | } // namespace |
572 | |