1//===- LowerHLFIRIntrinsics.cpp - Transformational intrinsics to FIR ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/IntrinsicCall.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/HLFIR/Passes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include <functional>
#include <optional>
28
29namespace hlfir {
30#define GEN_PASS_DEF_LOWERHLFIRINTRINSICS
31#include "flang/Optimizer/HLFIR/Passes.h.inc"
32} // namespace hlfir
33
34namespace {
35
36/// Base class for passes converting transformational intrinsic operations into
37/// runtime calls
38template <class OP>
39class HlfirIntrinsicConversion : public mlir::OpRewritePattern<OP> {
40public:
41 explicit HlfirIntrinsicConversion(mlir::MLIRContext *ctx)
42 : mlir::OpRewritePattern<OP>{ctx} {
43 // required for cases where intrinsics are chained together e.g.
44 // matmul(matmul(a, b), c)
45 // because converting the inner operation then invalidates the
46 // outer operation: causing the pattern to apply recursively.
47 //
48 // This is safe because we always progress with each iteration. Circular
49 // applications of operations are not expressible in MLIR because we use
50 // an SSA form and one must become first. E.g.
51 // %a = hlfir.matmul %b %d
52 // %b = hlfir.matmul %a %d
53 // cannot be written.
54 // MSVC needs the this->
55 this->setHasBoundedRewriteRecursion(true);
56 }
57
58protected:
59 struct IntrinsicArgument {
60 mlir::Value val; // allowed to be null if the argument is absent
61 mlir::Type desiredType;
62 };
63
64 /// Lower the arguments to the intrinsic: adding necessary boxing and
65 /// conversion to match the signature of the intrinsic in the runtime library.
66 llvm::SmallVector<fir::ExtendedValue, 3>
67 lowerArguments(mlir::Operation *op,
68 const llvm::ArrayRef<IntrinsicArgument> &args,
69 mlir::PatternRewriter &rewriter,
70 const fir::IntrinsicArgumentLoweringRules *argLowering) const {
71 mlir::Location loc = op->getLoc();
72 fir::FirOpBuilder builder{rewriter, op};
73
74 llvm::SmallVector<fir::ExtendedValue, 3> ret;
75 llvm::SmallVector<std::function<void()>, 2> cleanupFns;
76
77 for (size_t i = 0; i < args.size(); ++i) {
78 mlir::Value arg = args[i].val;
79 mlir::Type desiredType = args[i].desiredType;
80 if (!arg) {
81 ret.emplace_back(fir::getAbsentIntrinsicArgument());
82 continue;
83 }
84 hlfir::Entity entity{arg};
85
86 fir::ArgLoweringRule argRules =
87 fir::lowerIntrinsicArgumentAs(*argLowering, i);
88 switch (argRules.lowerAs) {
89 case fir::LowerIntrinsicArgAs::Value: {
90 if (args[i].desiredType != arg.getType()) {
91 arg = builder.createConvert(loc, desiredType, arg);
92 entity = hlfir::Entity{arg};
93 }
94 auto [exv, cleanup] = hlfir::convertToValue(loc, builder, entity);
95 if (cleanup)
96 cleanupFns.push_back(*cleanup);
97 ret.emplace_back(exv);
98 } break;
99 case fir::LowerIntrinsicArgAs::Addr: {
100 auto [exv, cleanup] =
101 hlfir::convertToAddress(loc, builder, entity, desiredType);
102 if (cleanup)
103 cleanupFns.push_back(*cleanup);
104 ret.emplace_back(exv);
105 } break;
106 case fir::LowerIntrinsicArgAs::Box: {
107 auto [box, cleanup] =
108 hlfir::convertToBox(loc, builder, entity, desiredType);
109 if (cleanup)
110 cleanupFns.push_back(*cleanup);
111 ret.emplace_back(box);
112 } break;
113 case fir::LowerIntrinsicArgAs::Inquired: {
114 if (args[i].desiredType != arg.getType()) {
115 arg = builder.createConvert(loc, desiredType, arg);
116 entity = hlfir::Entity{arg};
117 }
118 // Place hlfir.expr in memory, and unbox fir.boxchar. Other entities
119 // are translated to fir::ExtendedValue without transofrmation (notably,
120 // pointers/allocatable are not dereferenced).
121 // TODO: once lowering to FIR retires, UBOUND and LBOUND can be
122 // simplified since the fir.box lowered here are now guarenteed to
123 // contain the local lower bounds thanks to the hlfir.declare (the extra
124 // rebox can be removed).
125 auto [exv, cleanup] =
126 hlfir::translateToExtendedValue(loc, builder, entity);
127 if (cleanup)
128 cleanupFns.push_back(*cleanup);
129 ret.emplace_back(exv);
130 } break;
131 }
132 }
133
134 if (cleanupFns.size()) {
135 auto oldInsertionPoint = builder.saveInsertionPoint();
136 builder.setInsertionPointAfter(op);
137 for (std::function<void()> cleanup : cleanupFns)
138 cleanup();
139 builder.restoreInsertionPoint(oldInsertionPoint);
140 }
141
142 return ret;
143 }
144
145 void processReturnValue(mlir::Operation *op,
146 const fir::ExtendedValue &resultExv, bool mustBeFreed,
147 fir::FirOpBuilder &builder,
148 mlir::PatternRewriter &rewriter) const {
149 mlir::Location loc = op->getLoc();
150
151 mlir::Value firBase = fir::getBase(resultExv);
152 mlir::Type firBaseTy = firBase.getType();
153
154 std::optional<hlfir::EntityWithAttributes> resultEntity;
155 if (fir::isa_trivial(firBaseTy)) {
156 // Some intrinsics return i1 when the original operation
157 // produces fir.logical<>, so we may need to cast it.
158 firBase = builder.createConvert(loc, op->getResult(0).getType(), firBase);
159 resultEntity = hlfir::EntityWithAttributes{firBase};
160 } else {
161 resultEntity =
162 hlfir::genDeclare(loc, builder, resultExv, ".tmp.intrinsic_result",
163 fir::FortranVariableFlagsAttr{});
164 }
165
166 if (resultEntity->isVariable()) {
167 hlfir::AsExprOp asExpr = builder.create<hlfir::AsExprOp>(
168 loc, *resultEntity, builder.createBool(loc, mustBeFreed));
169 resultEntity = hlfir::EntityWithAttributes{asExpr.getResult()};
170 }
171
172 mlir::Value base = resultEntity->getBase();
173 if (!mlir::isa<hlfir::ExprType>(base.getType())) {
174 for (mlir::Operation *use : op->getResult(0).getUsers()) {
175 if (mlir::isa<hlfir::DestroyOp>(use))
176 rewriter.eraseOp(use);
177 }
178 }
179
180 rewriter.replaceOp(op, base);
181 }
182};
183
184// Given an integer or array of integer type, calculate the Kind parameter from
185// the width for use in runtime intrinsic calls.
186static unsigned getKindForType(mlir::Type ty) {
187 mlir::Type eltty = hlfir::getFortranElementType(ty);
188 unsigned width = eltty.cast<mlir::IntegerType>().getWidth();
189 return width / 8;
190}
191
192template <class OP>
193class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> {
194 using HlfirIntrinsicConversion<OP>::HlfirIntrinsicConversion;
195 using IntrinsicArgument =
196 typename HlfirIntrinsicConversion<OP>::IntrinsicArgument;
197 using HlfirIntrinsicConversion<OP>::lowerArguments;
198 using HlfirIntrinsicConversion<OP>::processReturnValue;
199
200protected:
201 auto buildNumericalArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
202 mlir::PatternRewriter &rewriter,
203 std::string opName) const {
204 llvm::SmallVector<IntrinsicArgument, 3> inArgs;
205 inArgs.push_back({operation.getArray(), operation.getArray().getType()});
206 inArgs.push_back({operation.getDim(), i32});
207 inArgs.push_back({operation.getMask(), logicalType});
208 auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
209 return lowerArguments(operation, inArgs, rewriter, argLowering);
210 };
211
212 auto buildMinMaxLocArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
213 mlir::PatternRewriter &rewriter, std::string opName,
214 fir::FirOpBuilder builder) const {
215 llvm::SmallVector<IntrinsicArgument, 3> inArgs;
216 inArgs.push_back({operation.getArray(), operation.getArray().getType()});
217 inArgs.push_back({operation.getDim(), i32});
218 inArgs.push_back({operation.getMask(), logicalType});
219 mlir::Value kind = builder.createIntegerConstant(
220 operation->getLoc(), i32, getKindForType(operation.getType()));
221 inArgs.push_back({kind, i32});
222 inArgs.push_back({operation.getBack(), i32});
223 auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
224 return lowerArguments(operation, inArgs, rewriter, argLowering);
225 };
226
227 auto buildLogicalArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
228 mlir::PatternRewriter &rewriter,
229 std::string opName) const {
230 llvm::SmallVector<IntrinsicArgument, 2> inArgs;
231 inArgs.push_back({operation.getMask(), logicalType});
232 inArgs.push_back({operation.getDim(), i32});
233 auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
234 return lowerArguments(operation, inArgs, rewriter, argLowering);
235 };
236
237public:
238 mlir::LogicalResult
239 matchAndRewrite(OP operation,
240 mlir::PatternRewriter &rewriter) const override {
241 std::string opName;
242 if constexpr (std::is_same_v<OP, hlfir::SumOp>) {
243 opName = "sum";
244 } else if constexpr (std::is_same_v<OP, hlfir::ProductOp>) {
245 opName = "product";
246 } else if constexpr (std::is_same_v<OP, hlfir::MaxvalOp>) {
247 opName = "maxval";
248 } else if constexpr (std::is_same_v<OP, hlfir::MinvalOp>) {
249 opName = "minval";
250 } else if constexpr (std::is_same_v<OP, hlfir::MinlocOp>) {
251 opName = "minloc";
252 } else if constexpr (std::is_same_v<OP, hlfir::MaxlocOp>) {
253 opName = "maxloc";
254 } else if constexpr (std::is_same_v<OP, hlfir::AnyOp>) {
255 opName = "any";
256 } else if constexpr (std::is_same_v<OP, hlfir::AllOp>) {
257 opName = "all";
258 } else {
259 return mlir::failure();
260 }
261
262 fir::FirOpBuilder builder{rewriter, operation.getOperation()};
263 const mlir::Location &loc = operation->getLoc();
264
265 mlir::Type i32 = builder.getI32Type();
266 mlir::Type logicalType = fir::LogicalType::get(
267 builder.getContext(), builder.getKindMap().defaultLogicalKind());
268
269 llvm::SmallVector<fir::ExtendedValue, 0> args;
270
271 if constexpr (std::is_same_v<OP, hlfir::SumOp> ||
272 std::is_same_v<OP, hlfir::ProductOp> ||
273 std::is_same_v<OP, hlfir::MaxvalOp> ||
274 std::is_same_v<OP, hlfir::MinvalOp>) {
275 args = buildNumericalArgs(operation, i32, logicalType, rewriter, opName);
276 } else if constexpr (std::is_same_v<OP, hlfir::MinlocOp> ||
277 std::is_same_v<OP, hlfir::MaxlocOp>) {
278 args = buildMinMaxLocArgs(operation, i32, logicalType, rewriter, opName,
279 builder);
280 } else {
281 args = buildLogicalArgs(operation, i32, logicalType, rewriter, opName);
282 }
283
284 mlir::Type scalarResultType =
285 hlfir::getFortranElementType(operation.getType());
286
287 auto [resultExv, mustBeFreed] =
288 fir::genIntrinsicCall(builder, loc, opName, scalarResultType, args);
289
290 processReturnValue(operation, resultExv, mustBeFreed, builder, rewriter);
291 return mlir::success();
292 }
293};
294
295using SumOpConversion = HlfirReductionIntrinsicConversion<hlfir::SumOp>;
296
297using ProductOpConversion = HlfirReductionIntrinsicConversion<hlfir::ProductOp>;
298
299using MaxvalOpConversion = HlfirReductionIntrinsicConversion<hlfir::MaxvalOp>;
300
301using MinvalOpConversion = HlfirReductionIntrinsicConversion<hlfir::MinvalOp>;
302
303using MinlocOpConversion = HlfirReductionIntrinsicConversion<hlfir::MinlocOp>;
304
305using MaxlocOpConversion = HlfirReductionIntrinsicConversion<hlfir::MaxlocOp>;
306
307using AnyOpConversion = HlfirReductionIntrinsicConversion<hlfir::AnyOp>;
308
309using AllOpConversion = HlfirReductionIntrinsicConversion<hlfir::AllOp>;
310
311struct CountOpConversion : public HlfirIntrinsicConversion<hlfir::CountOp> {
312 using HlfirIntrinsicConversion<hlfir::CountOp>::HlfirIntrinsicConversion;
313
314 mlir::LogicalResult
315 matchAndRewrite(hlfir::CountOp count,
316 mlir::PatternRewriter &rewriter) const override {
317 fir::FirOpBuilder builder{rewriter, count.getOperation()};
318 const mlir::Location &loc = count->getLoc();
319
320 mlir::Type i32 = builder.getI32Type();
321 mlir::Type logicalType = fir::LogicalType::get(
322 builder.getContext(), builder.getKindMap().defaultLogicalKind());
323
324 llvm::SmallVector<IntrinsicArgument, 3> inArgs;
325 inArgs.push_back({count.getMask(), logicalType});
326 inArgs.push_back({count.getDim(), i32});
327 mlir::Value kind = builder.createIntegerConstant(
328 count->getLoc(), i32, getKindForType(count.getType()));
329 inArgs.push_back({kind, i32});
330
331 auto *argLowering = fir::getIntrinsicArgumentLowering("count");
332 llvm::SmallVector<fir::ExtendedValue, 3> args =
333 lowerArguments(count, inArgs, rewriter, argLowering);
334
335 mlir::Type scalarResultType = hlfir::getFortranElementType(count.getType());
336
337 auto [resultExv, mustBeFreed] =
338 fir::genIntrinsicCall(builder, loc, "count", scalarResultType, args);
339
340 processReturnValue(count, resultExv, mustBeFreed, builder, rewriter);
341 return mlir::success();
342 }
343};
344
345struct MatmulOpConversion : public HlfirIntrinsicConversion<hlfir::MatmulOp> {
346 using HlfirIntrinsicConversion<hlfir::MatmulOp>::HlfirIntrinsicConversion;
347
348 mlir::LogicalResult
349 matchAndRewrite(hlfir::MatmulOp matmul,
350 mlir::PatternRewriter &rewriter) const override {
351 fir::FirOpBuilder builder{rewriter, matmul.getOperation()};
352 const mlir::Location &loc = matmul->getLoc();
353
354 mlir::Value lhs = matmul.getLhs();
355 mlir::Value rhs = matmul.getRhs();
356 llvm::SmallVector<IntrinsicArgument, 2> inArgs;
357 inArgs.push_back({lhs, lhs.getType()});
358 inArgs.push_back({rhs, rhs.getType()});
359
360 auto *argLowering = fir::getIntrinsicArgumentLowering("matmul");
361 llvm::SmallVector<fir::ExtendedValue, 2> args =
362 lowerArguments(matmul, inArgs, rewriter, argLowering);
363
364 mlir::Type scalarResultType =
365 hlfir::getFortranElementType(matmul.getType());
366
367 auto [resultExv, mustBeFreed] =
368 fir::genIntrinsicCall(builder, loc, "matmul", scalarResultType, args);
369
370 processReturnValue(matmul, resultExv, mustBeFreed, builder, rewriter);
371 return mlir::success();
372 }
373};
374
375struct DotProductOpConversion
376 : public HlfirIntrinsicConversion<hlfir::DotProductOp> {
377 using HlfirIntrinsicConversion<hlfir::DotProductOp>::HlfirIntrinsicConversion;
378
379 mlir::LogicalResult
380 matchAndRewrite(hlfir::DotProductOp dotProduct,
381 mlir::PatternRewriter &rewriter) const override {
382 fir::FirOpBuilder builder{rewriter, dotProduct.getOperation()};
383 const mlir::Location &loc = dotProduct->getLoc();
384
385 mlir::Value lhs = dotProduct.getLhs();
386 mlir::Value rhs = dotProduct.getRhs();
387 llvm::SmallVector<IntrinsicArgument, 2> inArgs;
388 inArgs.push_back({lhs, lhs.getType()});
389 inArgs.push_back({rhs, rhs.getType()});
390
391 auto *argLowering = fir::getIntrinsicArgumentLowering("dot_product");
392 llvm::SmallVector<fir::ExtendedValue, 2> args =
393 lowerArguments(dotProduct, inArgs, rewriter, argLowering);
394
395 mlir::Type scalarResultType =
396 hlfir::getFortranElementType(dotProduct.getType());
397
398 auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(
399 builder, loc, "dot_product", scalarResultType, args);
400
401 processReturnValue(dotProduct, resultExv, mustBeFreed, builder, rewriter);
402 return mlir::success();
403 }
404};
405
406class TransposeOpConversion
407 : public HlfirIntrinsicConversion<hlfir::TransposeOp> {
408 using HlfirIntrinsicConversion<hlfir::TransposeOp>::HlfirIntrinsicConversion;
409
410 mlir::LogicalResult
411 matchAndRewrite(hlfir::TransposeOp transpose,
412 mlir::PatternRewriter &rewriter) const override {
413 fir::FirOpBuilder builder{rewriter, transpose.getOperation()};
414 const mlir::Location &loc = transpose->getLoc();
415
416 mlir::Value arg = transpose.getArray();
417 llvm::SmallVector<IntrinsicArgument, 1> inArgs;
418 inArgs.push_back({arg, arg.getType()});
419
420 auto *argLowering = fir::getIntrinsicArgumentLowering("transpose");
421 llvm::SmallVector<fir::ExtendedValue, 1> args =
422 lowerArguments(transpose, inArgs, rewriter, argLowering);
423
424 mlir::Type scalarResultType =
425 hlfir::getFortranElementType(transpose.getType());
426
427 auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(
428 builder, loc, "transpose", scalarResultType, args);
429
430 processReturnValue(transpose, resultExv, mustBeFreed, builder, rewriter);
431 return mlir::success();
432 }
433};
434
435struct MatmulTransposeOpConversion
436 : public HlfirIntrinsicConversion<hlfir::MatmulTransposeOp> {
437 using HlfirIntrinsicConversion<
438 hlfir::MatmulTransposeOp>::HlfirIntrinsicConversion;
439
440 mlir::LogicalResult
441 matchAndRewrite(hlfir::MatmulTransposeOp multranspose,
442 mlir::PatternRewriter &rewriter) const override {
443 fir::FirOpBuilder builder{rewriter, multranspose.getOperation()};
444 const mlir::Location &loc = multranspose->getLoc();
445
446 mlir::Value lhs = multranspose.getLhs();
447 mlir::Value rhs = multranspose.getRhs();
448 llvm::SmallVector<IntrinsicArgument, 2> inArgs;
449 inArgs.push_back({lhs, lhs.getType()});
450 inArgs.push_back({rhs, rhs.getType()});
451
452 auto *argLowering = fir::getIntrinsicArgumentLowering("matmul");
453 llvm::SmallVector<fir::ExtendedValue, 2> args =
454 lowerArguments(multranspose, inArgs, rewriter, argLowering);
455
456 mlir::Type scalarResultType =
457 hlfir::getFortranElementType(multranspose.getType());
458
459 auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(
460 builder, loc, "matmul_transpose", scalarResultType, args);
461
462 processReturnValue(multranspose, resultExv, mustBeFreed, builder, rewriter);
463 return mlir::success();
464 }
465};
466
467class LowerHLFIRIntrinsics
468 : public hlfir::impl::LowerHLFIRIntrinsicsBase<LowerHLFIRIntrinsics> {
469public:
470 void runOnOperation() override {
471 // TODO: make this a pass operating on FuncOp. The issue is that
472 // FirOpBuilder helpers may generate new FuncOp because of runtime/llvm
473 // intrinsics calls creation. This may create race conflict if the pass is
474 // scheduled on FuncOp. A solution could be to provide an optional mutex
475 // when building a FirOpBuilder and locking around FuncOp and GlobalOp
476 // creation, but this needs a bit more thinking, so at this point the pass
477 // is scheduled on the moduleOp.
478 mlir::ModuleOp module = this->getOperation();
479 mlir::MLIRContext *context = &getContext();
480 mlir::RewritePatternSet patterns(context);
481 patterns
482 .insert<MatmulOpConversion, MatmulTransposeOpConversion,
483 AllOpConversion, AnyOpConversion, SumOpConversion,
484 ProductOpConversion, TransposeOpConversion, CountOpConversion,
485 DotProductOpConversion, MaxvalOpConversion, MinvalOpConversion,
486 MinlocOpConversion, MaxlocOpConversion>(context);
487
488 // While conceptually this pass is performing dialect conversion, we use
489 // pattern rewrites here instead of dialect conversion because this pass
490 // looses array bounds from some of the expressions e.g.
491 // !hlfir.expr<2xi32> -> !hlfir.expr<?xi32>
492 // MLIR thinks this is a different type so dialect conversion fails.
493 // Pattern rewriting only requires that the resulting IR is still valid
494 mlir::GreedyRewriteConfig config;
495 // Prevent the pattern driver from merging blocks
496 config.enableRegionSimplification = false;
497
498 if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
499 module, std::move(patterns), config))) {
500 mlir::emitError(mlir::UnknownLoc::get(context),
501 "failure in HLFIR intrinsic lowering");
502 signalPassFailure();
503 }
504 }
505};
506} // namespace
507
508std::unique_ptr<mlir::Pass> hlfir::createLowerHLFIRIntrinsicsPass() {
509 return std::make_unique<LowerHLFIRIntrinsics>();
510}
511

source code of flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp