1//===- LowerHLFIRIntrinsics.cpp - Transformational intrinsics to FIR ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/IntrinsicCall.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/HLFIR/Passes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include <cassert>
#include <functional>
#include <optional>
27
28namespace hlfir {
29#define GEN_PASS_DEF_LOWERHLFIRINTRINSICS
30#include "flang/Optimizer/HLFIR/Passes.h.inc"
31} // namespace hlfir
32
33namespace {
34
35/// Base class for passes converting transformational intrinsic operations into
36/// runtime calls
37template <class OP>
38class HlfirIntrinsicConversion : public mlir::OpRewritePattern<OP> {
39public:
40 explicit HlfirIntrinsicConversion(mlir::MLIRContext *ctx)
41 : mlir::OpRewritePattern<OP>{ctx} {
42 // required for cases where intrinsics are chained together e.g.
43 // matmul(matmul(a, b), c)
44 // because converting the inner operation then invalidates the
45 // outer operation: causing the pattern to apply recursively.
46 //
47 // This is safe because we always progress with each iteration. Circular
48 // applications of operations are not expressible in MLIR because we use
49 // an SSA form and one must become first. E.g.
50 // %a = hlfir.matmul %b %d
51 // %b = hlfir.matmul %a %d
52 // cannot be written.
53 // MSVC needs the this->
54 this->setHasBoundedRewriteRecursion(true);
55 }
56
57protected:
58 struct IntrinsicArgument {
59 mlir::Value val; // allowed to be null if the argument is absent
60 mlir::Type desiredType;
61 };
62
63 /// Lower the arguments to the intrinsic: adding necessary boxing and
64 /// conversion to match the signature of the intrinsic in the runtime library.
65 llvm::SmallVector<fir::ExtendedValue, 3>
66 lowerArguments(mlir::Operation *op,
67 const llvm::ArrayRef<IntrinsicArgument> &args,
68 mlir::PatternRewriter &rewriter,
69 const fir::IntrinsicArgumentLoweringRules *argLowering) const {
70 mlir::Location loc = op->getLoc();
71 fir::FirOpBuilder builder{rewriter, op};
72
73 llvm::SmallVector<fir::ExtendedValue, 3> ret;
74 llvm::SmallVector<std::function<void()>, 2> cleanupFns;
75
76 for (size_t i = 0; i < args.size(); ++i) {
77 mlir::Value arg = args[i].val;
78 mlir::Type desiredType = args[i].desiredType;
79 if (!arg) {
80 ret.emplace_back(fir::getAbsentIntrinsicArgument());
81 continue;
82 }
83 hlfir::Entity entity{arg};
84
85 fir::ArgLoweringRule argRules =
86 fir::lowerIntrinsicArgumentAs(*argLowering, i);
87 switch (argRules.lowerAs) {
88 case fir::LowerIntrinsicArgAs::Value: {
89 if (args[i].desiredType != arg.getType()) {
90 arg = builder.createConvert(loc, desiredType, arg);
91 entity = hlfir::Entity{arg};
92 }
93 auto [exv, cleanup] = hlfir::convertToValue(loc, builder, entity);
94 if (cleanup)
95 cleanupFns.push_back(*cleanup);
96 ret.emplace_back(exv);
97 } break;
98 case fir::LowerIntrinsicArgAs::Addr: {
99 auto [exv, cleanup] =
100 hlfir::convertToAddress(loc, builder, entity, desiredType);
101 if (cleanup)
102 cleanupFns.push_back(*cleanup);
103 ret.emplace_back(exv);
104 } break;
105 case fir::LowerIntrinsicArgAs::Box: {
106 auto [box, cleanup] =
107 hlfir::convertToBox(loc, builder, entity, desiredType);
108 if (cleanup)
109 cleanupFns.push_back(*cleanup);
110 ret.emplace_back(box);
111 } break;
112 case fir::LowerIntrinsicArgAs::Inquired: {
113 if (args[i].desiredType != arg.getType()) {
114 arg = builder.createConvert(loc, desiredType, arg);
115 entity = hlfir::Entity{arg};
116 }
117 // Place hlfir.expr in memory, and unbox fir.boxchar. Other entities
118 // are translated to fir::ExtendedValue without transofrmation (notably,
119 // pointers/allocatable are not dereferenced).
120 // TODO: once lowering to FIR retires, UBOUND and LBOUND can be
121 // simplified since the fir.box lowered here are now guarenteed to
122 // contain the local lower bounds thanks to the hlfir.declare (the extra
123 // rebox can be removed).
124 // When taking arguments as descriptors, the runtime expect absent
125 // OPTIONAL to be a nullptr to a descriptor, lowering has already
126 // prepared such descriptors as needed, hence set
127 // keepScalarOptionalBoxed to avoid building descriptors with a null
128 // address for them.
129 auto [exv, cleanup] = hlfir::translateToExtendedValue(
130 loc, builder, entity, /*contiguous=*/false,
131 /*keepScalarOptionalBoxed=*/true);
132 if (cleanup)
133 cleanupFns.push_back(*cleanup);
134 ret.emplace_back(exv);
135 } break;
136 }
137 }
138
139 if (cleanupFns.size()) {
140 auto oldInsertionPoint = builder.saveInsertionPoint();
141 builder.setInsertionPointAfter(op);
142 for (std::function<void()> cleanup : cleanupFns)
143 cleanup();
144 builder.restoreInsertionPoint(oldInsertionPoint);
145 }
146
147 return ret;
148 }
149
150 void processReturnValue(mlir::Operation *op,
151 const fir::ExtendedValue &resultExv, bool mustBeFreed,
152 fir::FirOpBuilder &builder,
153 mlir::PatternRewriter &rewriter) const {
154 mlir::Location loc = op->getLoc();
155
156 mlir::Value firBase = fir::getBase(resultExv);
157 mlir::Type firBaseTy = firBase.getType();
158
159 std::optional<hlfir::EntityWithAttributes> resultEntity;
160 if (fir::isa_trivial(firBaseTy)) {
161 // Some intrinsics return i1 when the original operation
162 // produces fir.logical<>, so we may need to cast it.
163 firBase = builder.createConvert(loc, op->getResult(0).getType(), firBase);
164 resultEntity = hlfir::EntityWithAttributes{firBase};
165 } else {
166 resultEntity =
167 hlfir::genDeclare(loc, builder, resultExv, ".tmp.intrinsic_result",
168 fir::FortranVariableFlagsAttr{});
169 }
170
171 if (resultEntity->isVariable()) {
172 hlfir::AsExprOp asExpr = builder.create<hlfir::AsExprOp>(
173 loc, *resultEntity, builder.createBool(loc, mustBeFreed));
174 resultEntity = hlfir::EntityWithAttributes{asExpr.getResult()};
175 }
176
177 mlir::Value base = resultEntity->getBase();
178 if (!mlir::isa<hlfir::ExprType>(base.getType())) {
179 for (mlir::Operation *use : op->getResult(0).getUsers()) {
180 if (mlir::isa<hlfir::DestroyOp>(use))
181 rewriter.eraseOp(use);
182 }
183 }
184
185 rewriter.replaceOp(op, base);
186 }
187};
188
189// Given an integer or array of integer type, calculate the Kind parameter from
190// the width for use in runtime intrinsic calls.
191static unsigned getKindForType(mlir::Type ty) {
192 mlir::Type eltty = hlfir::getFortranElementType(ty);
193 unsigned width = mlir::cast<mlir::IntegerType>(eltty).getWidth();
194 return width / 8;
195}
196
197template <class OP>
198class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> {
199 using HlfirIntrinsicConversion<OP>::HlfirIntrinsicConversion;
200 using IntrinsicArgument =
201 typename HlfirIntrinsicConversion<OP>::IntrinsicArgument;
202 using HlfirIntrinsicConversion<OP>::lowerArguments;
203 using HlfirIntrinsicConversion<OP>::processReturnValue;
204
205protected:
206 auto buildNumericalArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
207 mlir::PatternRewriter &rewriter,
208 std::string opName) const {
209 llvm::SmallVector<IntrinsicArgument, 3> inArgs;
210 inArgs.push_back({operation.getArray(), operation.getArray().getType()});
211 inArgs.push_back({operation.getDim(), i32});
212 inArgs.push_back({operation.getMask(), logicalType});
213 auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
214 return lowerArguments(operation, inArgs, rewriter, argLowering);
215 };
216
217 auto buildMinMaxLocArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
218 mlir::PatternRewriter &rewriter, std::string opName,
219 fir::FirOpBuilder builder) const {
220 llvm::SmallVector<IntrinsicArgument, 3> inArgs;
221 inArgs.push_back({operation.getArray(), operation.getArray().getType()});
222 inArgs.push_back({operation.getDim(), i32});
223 inArgs.push_back({operation.getMask(), logicalType});
224 mlir::Value kind = builder.createIntegerConstant(
225 operation->getLoc(), i32, getKindForType(operation.getType()));
226 inArgs.push_back({kind, i32});
227 inArgs.push_back({operation.getBack(), i32});
228 auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
229 return lowerArguments(operation, inArgs, rewriter, argLowering);
230 };
231
232 auto buildLogicalArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
233 mlir::PatternRewriter &rewriter,
234 std::string opName) const {
235 llvm::SmallVector<IntrinsicArgument, 2> inArgs;
236 inArgs.push_back({operation.getMask(), logicalType});
237 inArgs.push_back({operation.getDim(), i32});
238 auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
239 return lowerArguments(operation, inArgs, rewriter, argLowering);
240 };
241
242public:
243 llvm::LogicalResult
244 matchAndRewrite(OP operation,
245 mlir::PatternRewriter &rewriter) const override {
246 std::string opName;
247 if constexpr (std::is_same_v<OP, hlfir::SumOp>) {
248 opName = "sum";
249 } else if constexpr (std::is_same_v<OP, hlfir::ProductOp>) {
250 opName = "product";
251 } else if constexpr (std::is_same_v<OP, hlfir::MaxvalOp>) {
252 opName = "maxval";
253 } else if constexpr (std::is_same_v<OP, hlfir::MinvalOp>) {
254 opName = "minval";
255 } else if constexpr (std::is_same_v<OP, hlfir::MinlocOp>) {
256 opName = "minloc";
257 } else if constexpr (std::is_same_v<OP, hlfir::MaxlocOp>) {
258 opName = "maxloc";
259 } else if constexpr (std::is_same_v<OP, hlfir::AnyOp>) {
260 opName = "any";
261 } else if constexpr (std::is_same_v<OP, hlfir::AllOp>) {
262 opName = "all";
263 } else {
264 return mlir::failure();
265 }
266
267 fir::FirOpBuilder builder{rewriter, operation.getOperation()};
268 const mlir::Location &loc = operation->getLoc();
269
270 mlir::Type i32 = builder.getI32Type();
271 mlir::Type logicalType = fir::LogicalType::get(
272 builder.getContext(), builder.getKindMap().defaultLogicalKind());
273
274 llvm::SmallVector<fir::ExtendedValue, 0> args;
275
276 if constexpr (std::is_same_v<OP, hlfir::SumOp> ||
277 std::is_same_v<OP, hlfir::ProductOp> ||
278 std::is_same_v<OP, hlfir::MaxvalOp> ||
279 std::is_same_v<OP, hlfir::MinvalOp>) {
280 args = buildNumericalArgs(operation, i32, logicalType, rewriter, opName);
281 } else if constexpr (std::is_same_v<OP, hlfir::MinlocOp> ||
282 std::is_same_v<OP, hlfir::MaxlocOp>) {
283 args = buildMinMaxLocArgs(operation, i32, logicalType, rewriter, opName,
284 builder);
285 } else {
286 args = buildLogicalArgs(operation, i32, logicalType, rewriter, opName);
287 }
288
289 mlir::Type scalarResultType =
290 hlfir::getFortranElementType(operation.getType());
291
292 auto [resultExv, mustBeFreed] =
293 fir::genIntrinsicCall(builder, loc, opName, scalarResultType, args);
294
295 processReturnValue(operation, resultExv, mustBeFreed, builder, rewriter);
296 return mlir::success();
297 }
298};
299
300using SumOpConversion = HlfirReductionIntrinsicConversion<hlfir::SumOp>;
301
302using ProductOpConversion = HlfirReductionIntrinsicConversion<hlfir::ProductOp>;
303
304using MaxvalOpConversion = HlfirReductionIntrinsicConversion<hlfir::MaxvalOp>;
305
306using MinvalOpConversion = HlfirReductionIntrinsicConversion<hlfir::MinvalOp>;
307
308using MinlocOpConversion = HlfirReductionIntrinsicConversion<hlfir::MinlocOp>;
309
310using MaxlocOpConversion = HlfirReductionIntrinsicConversion<hlfir::MaxlocOp>;
311
312using AnyOpConversion = HlfirReductionIntrinsicConversion<hlfir::AnyOp>;
313
314using AllOpConversion = HlfirReductionIntrinsicConversion<hlfir::AllOp>;
315
316struct CountOpConversion : public HlfirIntrinsicConversion<hlfir::CountOp> {
317 using HlfirIntrinsicConversion<hlfir::CountOp>::HlfirIntrinsicConversion;
318
319 llvm::LogicalResult
320 matchAndRewrite(hlfir::CountOp count,
321 mlir::PatternRewriter &rewriter) const override {
322 fir::FirOpBuilder builder{rewriter, count.getOperation()};
323 const mlir::Location &loc = count->getLoc();
324
325 mlir::Type i32 = builder.getI32Type();
326 mlir::Type logicalType = fir::LogicalType::get(
327 builder.getContext(), builder.getKindMap().defaultLogicalKind());
328
329 llvm::SmallVector<IntrinsicArgument, 3> inArgs;
330 inArgs.push_back({count.getMask(), logicalType});
331 inArgs.push_back({count.getDim(), i32});
332 mlir::Value kind = builder.createIntegerConstant(
333 count->getLoc(), i32, getKindForType(count.getType()));
334 inArgs.push_back({kind, i32});
335
336 auto *argLowering = fir::getIntrinsicArgumentLowering("count");
337 llvm::SmallVector<fir::ExtendedValue, 3> args =
338 lowerArguments(count, inArgs, rewriter, argLowering);
339
340 mlir::Type scalarResultType = hlfir::getFortranElementType(count.getType());
341
342 auto [resultExv, mustBeFreed] =
343 fir::genIntrinsicCall(builder, loc, "count", scalarResultType, args);
344
345 processReturnValue(count, resultExv, mustBeFreed, builder, rewriter);
346 return mlir::success();
347 }
348};
349
350struct MatmulOpConversion : public HlfirIntrinsicConversion<hlfir::MatmulOp> {
351 using HlfirIntrinsicConversion<hlfir::MatmulOp>::HlfirIntrinsicConversion;
352
353 llvm::LogicalResult
354 matchAndRewrite(hlfir::MatmulOp matmul,
355 mlir::PatternRewriter &rewriter) const override {
356 fir::FirOpBuilder builder{rewriter, matmul.getOperation()};
357 const mlir::Location &loc = matmul->getLoc();
358
359 mlir::Value lhs = matmul.getLhs();
360 mlir::Value rhs = matmul.getRhs();
361 llvm::SmallVector<IntrinsicArgument, 2> inArgs;
362 inArgs.push_back({lhs, lhs.getType()});
363 inArgs.push_back({rhs, rhs.getType()});
364
365 auto *argLowering = fir::getIntrinsicArgumentLowering("matmul");
366 llvm::SmallVector<fir::ExtendedValue, 2> args =
367 lowerArguments(matmul, inArgs, rewriter, argLowering);
368
369 mlir::Type scalarResultType =
370 hlfir::getFortranElementType(matmul.getType());
371
372 auto [resultExv, mustBeFreed] =
373 fir::genIntrinsicCall(builder, loc, "matmul", scalarResultType, args);
374
375 processReturnValue(matmul, resultExv, mustBeFreed, builder, rewriter);
376 return mlir::success();
377 }
378};
379
380struct DotProductOpConversion
381 : public HlfirIntrinsicConversion<hlfir::DotProductOp> {
382 using HlfirIntrinsicConversion<hlfir::DotProductOp>::HlfirIntrinsicConversion;
383
384 llvm::LogicalResult
385 matchAndRewrite(hlfir::DotProductOp dotProduct,
386 mlir::PatternRewriter &rewriter) const override {
387 fir::FirOpBuilder builder{rewriter, dotProduct.getOperation()};
388 const mlir::Location &loc = dotProduct->getLoc();
389
390 mlir::Value lhs = dotProduct.getLhs();
391 mlir::Value rhs = dotProduct.getRhs();
392 llvm::SmallVector<IntrinsicArgument, 2> inArgs;
393 inArgs.push_back({lhs, lhs.getType()});
394 inArgs.push_back({rhs, rhs.getType()});
395
396 auto *argLowering = fir::getIntrinsicArgumentLowering("dot_product");
397 llvm::SmallVector<fir::ExtendedValue, 2> args =
398 lowerArguments(dotProduct, inArgs, rewriter, argLowering);
399
400 mlir::Type scalarResultType =
401 hlfir::getFortranElementType(dotProduct.getType());
402
403 auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(
404 builder, loc, "dot_product", scalarResultType, args);
405
406 processReturnValue(dotProduct, resultExv, mustBeFreed, builder, rewriter);
407 return mlir::success();
408 }
409};
410
411class TransposeOpConversion
412 : public HlfirIntrinsicConversion<hlfir::TransposeOp> {
413 using HlfirIntrinsicConversion<hlfir::TransposeOp>::HlfirIntrinsicConversion;
414
415 llvm::LogicalResult
416 matchAndRewrite(hlfir::TransposeOp transpose,
417 mlir::PatternRewriter &rewriter) const override {
418 fir::FirOpBuilder builder{rewriter, transpose.getOperation()};
419 const mlir::Location &loc = transpose->getLoc();
420
421 mlir::Value arg = transpose.getArray();
422 llvm::SmallVector<IntrinsicArgument, 1> inArgs;
423 inArgs.push_back({arg, arg.getType()});
424
425 auto *argLowering = fir::getIntrinsicArgumentLowering("transpose");
426 llvm::SmallVector<fir::ExtendedValue, 1> args =
427 lowerArguments(transpose, inArgs, rewriter, argLowering);
428
429 mlir::Type scalarResultType =
430 hlfir::getFortranElementType(transpose.getType());
431
432 auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(
433 builder, loc, "transpose", scalarResultType, args);
434
435 processReturnValue(transpose, resultExv, mustBeFreed, builder, rewriter);
436 return mlir::success();
437 }
438};
439
440struct MatmulTransposeOpConversion
441 : public HlfirIntrinsicConversion<hlfir::MatmulTransposeOp> {
442 using HlfirIntrinsicConversion<
443 hlfir::MatmulTransposeOp>::HlfirIntrinsicConversion;
444
445 llvm::LogicalResult
446 matchAndRewrite(hlfir::MatmulTransposeOp multranspose,
447 mlir::PatternRewriter &rewriter) const override {
448 fir::FirOpBuilder builder{rewriter, multranspose.getOperation()};
449 const mlir::Location &loc = multranspose->getLoc();
450
451 mlir::Value lhs = multranspose.getLhs();
452 mlir::Value rhs = multranspose.getRhs();
453 llvm::SmallVector<IntrinsicArgument, 2> inArgs;
454 inArgs.push_back({lhs, lhs.getType()});
455 inArgs.push_back({rhs, rhs.getType()});
456
457 auto *argLowering = fir::getIntrinsicArgumentLowering("matmul");
458 llvm::SmallVector<fir::ExtendedValue, 2> args =
459 lowerArguments(multranspose, inArgs, rewriter, argLowering);
460
461 mlir::Type scalarResultType =
462 hlfir::getFortranElementType(multranspose.getType());
463
464 auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(
465 builder, loc, "matmul_transpose", scalarResultType, args);
466
467 processReturnValue(multranspose, resultExv, mustBeFreed, builder, rewriter);
468 return mlir::success();
469 }
470};
471
472class CShiftOpConversion : public HlfirIntrinsicConversion<hlfir::CShiftOp> {
473 using HlfirIntrinsicConversion<hlfir::CShiftOp>::HlfirIntrinsicConversion;
474
475 llvm::LogicalResult
476 matchAndRewrite(hlfir::CShiftOp cshift,
477 mlir::PatternRewriter &rewriter) const override {
478 fir::FirOpBuilder builder{rewriter, cshift.getOperation()};
479 const mlir::Location &loc = cshift->getLoc();
480
481 llvm::SmallVector<IntrinsicArgument, 3> inArgs;
482 mlir::Value array = cshift.getArray();
483 inArgs.push_back({array, array.getType()});
484 mlir::Value shift = cshift.getShift();
485 inArgs.push_back({shift, shift.getType()});
486 inArgs.push_back({cshift.getDim(), builder.getI32Type()});
487
488 auto *argLowering = fir::getIntrinsicArgumentLowering("cshift");
489 llvm::SmallVector<fir::ExtendedValue, 3> args =
490 lowerArguments(cshift, inArgs, rewriter, argLowering);
491
492 mlir::Type scalarResultType =
493 hlfir::getFortranElementType(cshift.getType());
494
495 auto [resultExv, mustBeFreed] =
496 fir::genIntrinsicCall(builder, loc, "cshift", scalarResultType, args);
497
498 processReturnValue(cshift, resultExv, mustBeFreed, builder, rewriter);
499 return mlir::success();
500 }
501};
502
503class ReshapeOpConversion : public HlfirIntrinsicConversion<hlfir::ReshapeOp> {
504 using HlfirIntrinsicConversion<hlfir::ReshapeOp>::HlfirIntrinsicConversion;
505
506 llvm::LogicalResult
507 matchAndRewrite(hlfir::ReshapeOp reshape,
508 mlir::PatternRewriter &rewriter) const override {
509 fir::FirOpBuilder builder{rewriter, reshape.getOperation()};
510 const mlir::Location &loc = reshape->getLoc();
511
512 llvm::SmallVector<IntrinsicArgument, 4> inArgs;
513 mlir::Value array = reshape.getArray();
514 inArgs.push_back({array, array.getType()});
515 mlir::Value shape = reshape.getShape();
516 inArgs.push_back({shape, shape.getType()});
517 mlir::Type noneType = builder.getNoneType();
518 mlir::Value pad = reshape.getPad();
519 inArgs.push_back({pad, pad ? pad.getType() : noneType});
520 mlir::Value order = reshape.getOrder();
521 inArgs.push_back({order, order ? order.getType() : noneType});
522
523 auto *argLowering = fir::getIntrinsicArgumentLowering("reshape");
524 llvm::SmallVector<fir::ExtendedValue, 4> args =
525 lowerArguments(reshape, inArgs, rewriter, argLowering);
526
527 mlir::Type scalarResultType =
528 hlfir::getFortranElementType(reshape.getType());
529
530 auto [resultExv, mustBeFreed] =
531 fir::genIntrinsicCall(builder, loc, "reshape", scalarResultType, args);
532
533 processReturnValue(reshape, resultExv, mustBeFreed, builder, rewriter);
534 return mlir::success();
535 }
536};
537
538class LowerHLFIRIntrinsics
539 : public hlfir::impl::LowerHLFIRIntrinsicsBase<LowerHLFIRIntrinsics> {
540public:
541 void runOnOperation() override {
542 mlir::ModuleOp module = this->getOperation();
543 mlir::MLIRContext *context = &getContext();
544 mlir::RewritePatternSet patterns(context);
545 patterns.insert<
546 MatmulOpConversion, MatmulTransposeOpConversion, AllOpConversion,
547 AnyOpConversion, SumOpConversion, ProductOpConversion,
548 TransposeOpConversion, CountOpConversion, DotProductOpConversion,
549 MaxvalOpConversion, MinvalOpConversion, MinlocOpConversion,
550 MaxlocOpConversion, CShiftOpConversion, ReshapeOpConversion>(context);
551
552 // While conceptually this pass is performing dialect conversion, we use
553 // pattern rewrites here instead of dialect conversion because this pass
554 // looses array bounds from some of the expressions e.g.
555 // !hlfir.expr<2xi32> -> !hlfir.expr<?xi32>
556 // MLIR thinks this is a different type so dialect conversion fails.
557 // Pattern rewriting only requires that the resulting IR is still valid
558 mlir::GreedyRewriteConfig config;
559 // Prevent the pattern driver from merging blocks
560 config.setRegionSimplificationLevel(
561 mlir::GreedySimplifyRegionLevel::Disabled);
562
563 if (mlir::failed(
564 mlir::applyPatternsGreedily(module, std::move(patterns), config))) {
565 mlir::emitError(mlir::UnknownLoc::get(context),
566 "failure in HLFIR intrinsic lowering");
567 signalPassFailure();
568 }
569 }
570};
571} // namespace
572

source code of flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp