1 | //===- LowerHLFIRIntrinsics.cpp - Transformational intrinsics to FIR ------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "flang/Optimizer/Builder/FIRBuilder.h" |
10 | #include "flang/Optimizer/Builder/HLFIRTools.h" |
11 | #include "flang/Optimizer/Builder/IntrinsicCall.h" |
12 | #include "flang/Optimizer/Builder/Todo.h" |
13 | #include "flang/Optimizer/Dialect/FIRDialect.h" |
14 | #include "flang/Optimizer/Dialect/FIROps.h" |
15 | #include "flang/Optimizer/Dialect/FIRType.h" |
16 | #include "flang/Optimizer/Dialect/Support/FIRContext.h" |
17 | #include "flang/Optimizer/HLFIR/HLFIRDialect.h" |
18 | #include "flang/Optimizer/HLFIR/HLFIROps.h" |
19 | #include "flang/Optimizer/HLFIR/Passes.h" |
20 | #include "mlir/IR/BuiltinDialect.h" |
21 | #include "mlir/IR/MLIRContext.h" |
22 | #include "mlir/IR/PatternMatch.h" |
23 | #include "mlir/Pass/Pass.h" |
24 | #include "mlir/Pass/PassManager.h" |
25 | #include "mlir/Support/LogicalResult.h" |
26 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
27 | #include <optional> |
28 | |
29 | namespace hlfir { |
30 | #define GEN_PASS_DEF_LOWERHLFIRINTRINSICS |
31 | #include "flang/Optimizer/HLFIR/Passes.h.inc" |
32 | } // namespace hlfir |
33 | |
34 | namespace { |
35 | |
36 | /// Base class for passes converting transformational intrinsic operations into |
37 | /// runtime calls |
38 | template <class OP> |
39 | class HlfirIntrinsicConversion : public mlir::OpRewritePattern<OP> { |
40 | public: |
41 | explicit HlfirIntrinsicConversion(mlir::MLIRContext *ctx) |
42 | : mlir::OpRewritePattern<OP>{ctx} { |
43 | // required for cases where intrinsics are chained together e.g. |
44 | // matmul(matmul(a, b), c) |
45 | // because converting the inner operation then invalidates the |
46 | // outer operation: causing the pattern to apply recursively. |
47 | // |
48 | // This is safe because we always progress with each iteration. Circular |
49 | // applications of operations are not expressible in MLIR because we use |
50 | // an SSA form and one must become first. E.g. |
51 | // %a = hlfir.matmul %b %d |
52 | // %b = hlfir.matmul %a %d |
53 | // cannot be written. |
54 | // MSVC needs the this-> |
55 | this->setHasBoundedRewriteRecursion(true); |
56 | } |
57 | |
58 | protected: |
59 | struct IntrinsicArgument { |
60 | mlir::Value val; // allowed to be null if the argument is absent |
61 | mlir::Type desiredType; |
62 | }; |
63 | |
64 | /// Lower the arguments to the intrinsic: adding necessary boxing and |
65 | /// conversion to match the signature of the intrinsic in the runtime library. |
66 | llvm::SmallVector<fir::ExtendedValue, 3> |
67 | lowerArguments(mlir::Operation *op, |
68 | const llvm::ArrayRef<IntrinsicArgument> &args, |
69 | mlir::PatternRewriter &rewriter, |
70 | const fir::IntrinsicArgumentLoweringRules *argLowering) const { |
71 | mlir::Location loc = op->getLoc(); |
72 | fir::FirOpBuilder builder{rewriter, op}; |
73 | |
74 | llvm::SmallVector<fir::ExtendedValue, 3> ret; |
75 | llvm::SmallVector<std::function<void()>, 2> cleanupFns; |
76 | |
77 | for (size_t i = 0; i < args.size(); ++i) { |
78 | mlir::Value arg = args[i].val; |
79 | mlir::Type desiredType = args[i].desiredType; |
80 | if (!arg) { |
81 | ret.emplace_back(fir::getAbsentIntrinsicArgument()); |
82 | continue; |
83 | } |
84 | hlfir::Entity entity{arg}; |
85 | |
86 | fir::ArgLoweringRule argRules = |
87 | fir::lowerIntrinsicArgumentAs(*argLowering, i); |
88 | switch (argRules.lowerAs) { |
89 | case fir::LowerIntrinsicArgAs::Value: { |
90 | if (args[i].desiredType != arg.getType()) { |
91 | arg = builder.createConvert(loc, desiredType, arg); |
92 | entity = hlfir::Entity{arg}; |
93 | } |
94 | auto [exv, cleanup] = hlfir::convertToValue(loc, builder, entity); |
95 | if (cleanup) |
96 | cleanupFns.push_back(*cleanup); |
97 | ret.emplace_back(exv); |
98 | } break; |
99 | case fir::LowerIntrinsicArgAs::Addr: { |
100 | auto [exv, cleanup] = |
101 | hlfir::convertToAddress(loc, builder, entity, desiredType); |
102 | if (cleanup) |
103 | cleanupFns.push_back(*cleanup); |
104 | ret.emplace_back(exv); |
105 | } break; |
106 | case fir::LowerIntrinsicArgAs::Box: { |
107 | auto [box, cleanup] = |
108 | hlfir::convertToBox(loc, builder, entity, desiredType); |
109 | if (cleanup) |
110 | cleanupFns.push_back(*cleanup); |
111 | ret.emplace_back(box); |
112 | } break; |
113 | case fir::LowerIntrinsicArgAs::Inquired: { |
114 | if (args[i].desiredType != arg.getType()) { |
115 | arg = builder.createConvert(loc, desiredType, arg); |
116 | entity = hlfir::Entity{arg}; |
117 | } |
118 | // Place hlfir.expr in memory, and unbox fir.boxchar. Other entities |
119 | // are translated to fir::ExtendedValue without transofrmation (notably, |
120 | // pointers/allocatable are not dereferenced). |
121 | // TODO: once lowering to FIR retires, UBOUND and LBOUND can be |
122 | // simplified since the fir.box lowered here are now guarenteed to |
123 | // contain the local lower bounds thanks to the hlfir.declare (the extra |
124 | // rebox can be removed). |
125 | auto [exv, cleanup] = |
126 | hlfir::translateToExtendedValue(loc, builder, entity); |
127 | if (cleanup) |
128 | cleanupFns.push_back(*cleanup); |
129 | ret.emplace_back(exv); |
130 | } break; |
131 | } |
132 | } |
133 | |
134 | if (cleanupFns.size()) { |
135 | auto oldInsertionPoint = builder.saveInsertionPoint(); |
136 | builder.setInsertionPointAfter(op); |
137 | for (std::function<void()> cleanup : cleanupFns) |
138 | cleanup(); |
139 | builder.restoreInsertionPoint(oldInsertionPoint); |
140 | } |
141 | |
142 | return ret; |
143 | } |
144 | |
145 | void processReturnValue(mlir::Operation *op, |
146 | const fir::ExtendedValue &resultExv, bool mustBeFreed, |
147 | fir::FirOpBuilder &builder, |
148 | mlir::PatternRewriter &rewriter) const { |
149 | mlir::Location loc = op->getLoc(); |
150 | |
151 | mlir::Value firBase = fir::getBase(resultExv); |
152 | mlir::Type firBaseTy = firBase.getType(); |
153 | |
154 | std::optional<hlfir::EntityWithAttributes> resultEntity; |
155 | if (fir::isa_trivial(firBaseTy)) { |
156 | // Some intrinsics return i1 when the original operation |
157 | // produces fir.logical<>, so we may need to cast it. |
158 | firBase = builder.createConvert(loc, op->getResult(0).getType(), firBase); |
159 | resultEntity = hlfir::EntityWithAttributes{firBase}; |
160 | } else { |
161 | resultEntity = |
162 | hlfir::genDeclare(loc, builder, resultExv, ".tmp.intrinsic_result" , |
163 | fir::FortranVariableFlagsAttr{}); |
164 | } |
165 | |
166 | if (resultEntity->isVariable()) { |
167 | hlfir::AsExprOp asExpr = builder.create<hlfir::AsExprOp>( |
168 | loc, *resultEntity, builder.createBool(loc, mustBeFreed)); |
169 | resultEntity = hlfir::EntityWithAttributes{asExpr.getResult()}; |
170 | } |
171 | |
172 | mlir::Value base = resultEntity->getBase(); |
173 | if (!mlir::isa<hlfir::ExprType>(base.getType())) { |
174 | for (mlir::Operation *use : op->getResult(0).getUsers()) { |
175 | if (mlir::isa<hlfir::DestroyOp>(use)) |
176 | rewriter.eraseOp(use); |
177 | } |
178 | } |
179 | |
180 | rewriter.replaceOp(op, base); |
181 | } |
182 | }; |
183 | |
184 | // Given an integer or array of integer type, calculate the Kind parameter from |
185 | // the width for use in runtime intrinsic calls. |
186 | static unsigned getKindForType(mlir::Type ty) { |
187 | mlir::Type eltty = hlfir::getFortranElementType(ty); |
188 | unsigned width = eltty.cast<mlir::IntegerType>().getWidth(); |
189 | return width / 8; |
190 | } |
191 | |
192 | template <class OP> |
193 | class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> { |
194 | using HlfirIntrinsicConversion<OP>::HlfirIntrinsicConversion; |
195 | using IntrinsicArgument = |
196 | typename HlfirIntrinsicConversion<OP>::IntrinsicArgument; |
197 | using HlfirIntrinsicConversion<OP>::lowerArguments; |
198 | using HlfirIntrinsicConversion<OP>::processReturnValue; |
199 | |
200 | protected: |
201 | auto buildNumericalArgs(OP operation, mlir::Type i32, mlir::Type logicalType, |
202 | mlir::PatternRewriter &rewriter, |
203 | std::string opName) const { |
204 | llvm::SmallVector<IntrinsicArgument, 3> inArgs; |
205 | inArgs.push_back({operation.getArray(), operation.getArray().getType()}); |
206 | inArgs.push_back({operation.getDim(), i32}); |
207 | inArgs.push_back({operation.getMask(), logicalType}); |
208 | auto *argLowering = fir::getIntrinsicArgumentLowering(opName); |
209 | return lowerArguments(operation, inArgs, rewriter, argLowering); |
210 | }; |
211 | |
212 | auto buildMinMaxLocArgs(OP operation, mlir::Type i32, mlir::Type logicalType, |
213 | mlir::PatternRewriter &rewriter, std::string opName, |
214 | fir::FirOpBuilder builder) const { |
215 | llvm::SmallVector<IntrinsicArgument, 3> inArgs; |
216 | inArgs.push_back({operation.getArray(), operation.getArray().getType()}); |
217 | inArgs.push_back({operation.getDim(), i32}); |
218 | inArgs.push_back({operation.getMask(), logicalType}); |
219 | mlir::Value kind = builder.createIntegerConstant( |
220 | operation->getLoc(), i32, getKindForType(operation.getType())); |
221 | inArgs.push_back({kind, i32}); |
222 | inArgs.push_back({operation.getBack(), i32}); |
223 | auto *argLowering = fir::getIntrinsicArgumentLowering(opName); |
224 | return lowerArguments(operation, inArgs, rewriter, argLowering); |
225 | }; |
226 | |
227 | auto buildLogicalArgs(OP operation, mlir::Type i32, mlir::Type logicalType, |
228 | mlir::PatternRewriter &rewriter, |
229 | std::string opName) const { |
230 | llvm::SmallVector<IntrinsicArgument, 2> inArgs; |
231 | inArgs.push_back({operation.getMask(), logicalType}); |
232 | inArgs.push_back({operation.getDim(), i32}); |
233 | auto *argLowering = fir::getIntrinsicArgumentLowering(opName); |
234 | return lowerArguments(operation, inArgs, rewriter, argLowering); |
235 | }; |
236 | |
237 | public: |
238 | mlir::LogicalResult |
239 | matchAndRewrite(OP operation, |
240 | mlir::PatternRewriter &rewriter) const override { |
241 | std::string opName; |
242 | if constexpr (std::is_same_v<OP, hlfir::SumOp>) { |
243 | opName = "sum" ; |
244 | } else if constexpr (std::is_same_v<OP, hlfir::ProductOp>) { |
245 | opName = "product" ; |
246 | } else if constexpr (std::is_same_v<OP, hlfir::MaxvalOp>) { |
247 | opName = "maxval" ; |
248 | } else if constexpr (std::is_same_v<OP, hlfir::MinvalOp>) { |
249 | opName = "minval" ; |
250 | } else if constexpr (std::is_same_v<OP, hlfir::MinlocOp>) { |
251 | opName = "minloc" ; |
252 | } else if constexpr (std::is_same_v<OP, hlfir::MaxlocOp>) { |
253 | opName = "maxloc" ; |
254 | } else if constexpr (std::is_same_v<OP, hlfir::AnyOp>) { |
255 | opName = "any" ; |
256 | } else if constexpr (std::is_same_v<OP, hlfir::AllOp>) { |
257 | opName = "all" ; |
258 | } else { |
259 | return mlir::failure(); |
260 | } |
261 | |
262 | fir::FirOpBuilder builder{rewriter, operation.getOperation()}; |
263 | const mlir::Location &loc = operation->getLoc(); |
264 | |
265 | mlir::Type i32 = builder.getI32Type(); |
266 | mlir::Type logicalType = fir::LogicalType::get( |
267 | builder.getContext(), builder.getKindMap().defaultLogicalKind()); |
268 | |
269 | llvm::SmallVector<fir::ExtendedValue, 0> args; |
270 | |
271 | if constexpr (std::is_same_v<OP, hlfir::SumOp> || |
272 | std::is_same_v<OP, hlfir::ProductOp> || |
273 | std::is_same_v<OP, hlfir::MaxvalOp> || |
274 | std::is_same_v<OP, hlfir::MinvalOp>) { |
275 | args = buildNumericalArgs(operation, i32, logicalType, rewriter, opName); |
276 | } else if constexpr (std::is_same_v<OP, hlfir::MinlocOp> || |
277 | std::is_same_v<OP, hlfir::MaxlocOp>) { |
278 | args = buildMinMaxLocArgs(operation, i32, logicalType, rewriter, opName, |
279 | builder); |
280 | } else { |
281 | args = buildLogicalArgs(operation, i32, logicalType, rewriter, opName); |
282 | } |
283 | |
284 | mlir::Type scalarResultType = |
285 | hlfir::getFortranElementType(operation.getType()); |
286 | |
287 | auto [resultExv, mustBeFreed] = |
288 | fir::genIntrinsicCall(builder, loc, opName, scalarResultType, args); |
289 | |
290 | processReturnValue(operation, resultExv, mustBeFreed, builder, rewriter); |
291 | return mlir::success(); |
292 | } |
293 | }; |
294 | |
295 | using SumOpConversion = HlfirReductionIntrinsicConversion<hlfir::SumOp>; |
296 | |
297 | using ProductOpConversion = HlfirReductionIntrinsicConversion<hlfir::ProductOp>; |
298 | |
299 | using MaxvalOpConversion = HlfirReductionIntrinsicConversion<hlfir::MaxvalOp>; |
300 | |
301 | using MinvalOpConversion = HlfirReductionIntrinsicConversion<hlfir::MinvalOp>; |
302 | |
303 | using MinlocOpConversion = HlfirReductionIntrinsicConversion<hlfir::MinlocOp>; |
304 | |
305 | using MaxlocOpConversion = HlfirReductionIntrinsicConversion<hlfir::MaxlocOp>; |
306 | |
307 | using AnyOpConversion = HlfirReductionIntrinsicConversion<hlfir::AnyOp>; |
308 | |
309 | using AllOpConversion = HlfirReductionIntrinsicConversion<hlfir::AllOp>; |
310 | |
311 | struct CountOpConversion : public HlfirIntrinsicConversion<hlfir::CountOp> { |
312 | using HlfirIntrinsicConversion<hlfir::CountOp>::HlfirIntrinsicConversion; |
313 | |
314 | mlir::LogicalResult |
315 | matchAndRewrite(hlfir::CountOp count, |
316 | mlir::PatternRewriter &rewriter) const override { |
317 | fir::FirOpBuilder builder{rewriter, count.getOperation()}; |
318 | const mlir::Location &loc = count->getLoc(); |
319 | |
320 | mlir::Type i32 = builder.getI32Type(); |
321 | mlir::Type logicalType = fir::LogicalType::get( |
322 | builder.getContext(), builder.getKindMap().defaultLogicalKind()); |
323 | |
324 | llvm::SmallVector<IntrinsicArgument, 3> inArgs; |
325 | inArgs.push_back({count.getMask(), logicalType}); |
326 | inArgs.push_back({count.getDim(), i32}); |
327 | mlir::Value kind = builder.createIntegerConstant( |
328 | count->getLoc(), i32, getKindForType(count.getType())); |
329 | inArgs.push_back({kind, i32}); |
330 | |
331 | auto *argLowering = fir::getIntrinsicArgumentLowering("count" ); |
332 | llvm::SmallVector<fir::ExtendedValue, 3> args = |
333 | lowerArguments(count, inArgs, rewriter, argLowering); |
334 | |
335 | mlir::Type scalarResultType = hlfir::getFortranElementType(count.getType()); |
336 | |
337 | auto [resultExv, mustBeFreed] = |
338 | fir::genIntrinsicCall(builder, loc, "count" , scalarResultType, args); |
339 | |
340 | processReturnValue(count, resultExv, mustBeFreed, builder, rewriter); |
341 | return mlir::success(); |
342 | } |
343 | }; |
344 | |
345 | struct MatmulOpConversion : public HlfirIntrinsicConversion<hlfir::MatmulOp> { |
346 | using HlfirIntrinsicConversion<hlfir::MatmulOp>::HlfirIntrinsicConversion; |
347 | |
348 | mlir::LogicalResult |
349 | matchAndRewrite(hlfir::MatmulOp matmul, |
350 | mlir::PatternRewriter &rewriter) const override { |
351 | fir::FirOpBuilder builder{rewriter, matmul.getOperation()}; |
352 | const mlir::Location &loc = matmul->getLoc(); |
353 | |
354 | mlir::Value lhs = matmul.getLhs(); |
355 | mlir::Value rhs = matmul.getRhs(); |
356 | llvm::SmallVector<IntrinsicArgument, 2> inArgs; |
357 | inArgs.push_back({lhs, lhs.getType()}); |
358 | inArgs.push_back({rhs, rhs.getType()}); |
359 | |
360 | auto *argLowering = fir::getIntrinsicArgumentLowering("matmul" ); |
361 | llvm::SmallVector<fir::ExtendedValue, 2> args = |
362 | lowerArguments(matmul, inArgs, rewriter, argLowering); |
363 | |
364 | mlir::Type scalarResultType = |
365 | hlfir::getFortranElementType(matmul.getType()); |
366 | |
367 | auto [resultExv, mustBeFreed] = |
368 | fir::genIntrinsicCall(builder, loc, "matmul" , scalarResultType, args); |
369 | |
370 | processReturnValue(matmul, resultExv, mustBeFreed, builder, rewriter); |
371 | return mlir::success(); |
372 | } |
373 | }; |
374 | |
375 | struct DotProductOpConversion |
376 | : public HlfirIntrinsicConversion<hlfir::DotProductOp> { |
377 | using HlfirIntrinsicConversion<hlfir::DotProductOp>::HlfirIntrinsicConversion; |
378 | |
379 | mlir::LogicalResult |
380 | matchAndRewrite(hlfir::DotProductOp dotProduct, |
381 | mlir::PatternRewriter &rewriter) const override { |
382 | fir::FirOpBuilder builder{rewriter, dotProduct.getOperation()}; |
383 | const mlir::Location &loc = dotProduct->getLoc(); |
384 | |
385 | mlir::Value lhs = dotProduct.getLhs(); |
386 | mlir::Value rhs = dotProduct.getRhs(); |
387 | llvm::SmallVector<IntrinsicArgument, 2> inArgs; |
388 | inArgs.push_back({lhs, lhs.getType()}); |
389 | inArgs.push_back({rhs, rhs.getType()}); |
390 | |
391 | auto *argLowering = fir::getIntrinsicArgumentLowering("dot_product" ); |
392 | llvm::SmallVector<fir::ExtendedValue, 2> args = |
393 | lowerArguments(dotProduct, inArgs, rewriter, argLowering); |
394 | |
395 | mlir::Type scalarResultType = |
396 | hlfir::getFortranElementType(dotProduct.getType()); |
397 | |
398 | auto [resultExv, mustBeFreed] = fir::genIntrinsicCall( |
399 | builder, loc, "dot_product" , scalarResultType, args); |
400 | |
401 | processReturnValue(dotProduct, resultExv, mustBeFreed, builder, rewriter); |
402 | return mlir::success(); |
403 | } |
404 | }; |
405 | |
406 | class TransposeOpConversion |
407 | : public HlfirIntrinsicConversion<hlfir::TransposeOp> { |
408 | using HlfirIntrinsicConversion<hlfir::TransposeOp>::HlfirIntrinsicConversion; |
409 | |
410 | mlir::LogicalResult |
411 | matchAndRewrite(hlfir::TransposeOp transpose, |
412 | mlir::PatternRewriter &rewriter) const override { |
413 | fir::FirOpBuilder builder{rewriter, transpose.getOperation()}; |
414 | const mlir::Location &loc = transpose->getLoc(); |
415 | |
416 | mlir::Value arg = transpose.getArray(); |
417 | llvm::SmallVector<IntrinsicArgument, 1> inArgs; |
418 | inArgs.push_back({arg, arg.getType()}); |
419 | |
420 | auto *argLowering = fir::getIntrinsicArgumentLowering("transpose" ); |
421 | llvm::SmallVector<fir::ExtendedValue, 1> args = |
422 | lowerArguments(transpose, inArgs, rewriter, argLowering); |
423 | |
424 | mlir::Type scalarResultType = |
425 | hlfir::getFortranElementType(transpose.getType()); |
426 | |
427 | auto [resultExv, mustBeFreed] = fir::genIntrinsicCall( |
428 | builder, loc, "transpose" , scalarResultType, args); |
429 | |
430 | processReturnValue(transpose, resultExv, mustBeFreed, builder, rewriter); |
431 | return mlir::success(); |
432 | } |
433 | }; |
434 | |
435 | struct MatmulTransposeOpConversion |
436 | : public HlfirIntrinsicConversion<hlfir::MatmulTransposeOp> { |
437 | using HlfirIntrinsicConversion< |
438 | hlfir::MatmulTransposeOp>::HlfirIntrinsicConversion; |
439 | |
440 | mlir::LogicalResult |
441 | matchAndRewrite(hlfir::MatmulTransposeOp multranspose, |
442 | mlir::PatternRewriter &rewriter) const override { |
443 | fir::FirOpBuilder builder{rewriter, multranspose.getOperation()}; |
444 | const mlir::Location &loc = multranspose->getLoc(); |
445 | |
446 | mlir::Value lhs = multranspose.getLhs(); |
447 | mlir::Value rhs = multranspose.getRhs(); |
448 | llvm::SmallVector<IntrinsicArgument, 2> inArgs; |
449 | inArgs.push_back({lhs, lhs.getType()}); |
450 | inArgs.push_back({rhs, rhs.getType()}); |
451 | |
452 | auto *argLowering = fir::getIntrinsicArgumentLowering("matmul" ); |
453 | llvm::SmallVector<fir::ExtendedValue, 2> args = |
454 | lowerArguments(multranspose, inArgs, rewriter, argLowering); |
455 | |
456 | mlir::Type scalarResultType = |
457 | hlfir::getFortranElementType(multranspose.getType()); |
458 | |
459 | auto [resultExv, mustBeFreed] = fir::genIntrinsicCall( |
460 | builder, loc, "matmul_transpose" , scalarResultType, args); |
461 | |
462 | processReturnValue(multranspose, resultExv, mustBeFreed, builder, rewriter); |
463 | return mlir::success(); |
464 | } |
465 | }; |
466 | |
467 | class LowerHLFIRIntrinsics |
468 | : public hlfir::impl::LowerHLFIRIntrinsicsBase<LowerHLFIRIntrinsics> { |
469 | public: |
470 | void runOnOperation() override { |
471 | // TODO: make this a pass operating on FuncOp. The issue is that |
472 | // FirOpBuilder helpers may generate new FuncOp because of runtime/llvm |
473 | // intrinsics calls creation. This may create race conflict if the pass is |
474 | // scheduled on FuncOp. A solution could be to provide an optional mutex |
475 | // when building a FirOpBuilder and locking around FuncOp and GlobalOp |
476 | // creation, but this needs a bit more thinking, so at this point the pass |
477 | // is scheduled on the moduleOp. |
478 | mlir::ModuleOp module = this->getOperation(); |
479 | mlir::MLIRContext *context = &getContext(); |
480 | mlir::RewritePatternSet patterns(context); |
481 | patterns |
482 | .insert<MatmulOpConversion, MatmulTransposeOpConversion, |
483 | AllOpConversion, AnyOpConversion, SumOpConversion, |
484 | ProductOpConversion, TransposeOpConversion, CountOpConversion, |
485 | DotProductOpConversion, MaxvalOpConversion, MinvalOpConversion, |
486 | MinlocOpConversion, MaxlocOpConversion>(context); |
487 | |
488 | // While conceptually this pass is performing dialect conversion, we use |
489 | // pattern rewrites here instead of dialect conversion because this pass |
490 | // looses array bounds from some of the expressions e.g. |
491 | // !hlfir.expr<2xi32> -> !hlfir.expr<?xi32> |
492 | // MLIR thinks this is a different type so dialect conversion fails. |
493 | // Pattern rewriting only requires that the resulting IR is still valid |
494 | mlir::GreedyRewriteConfig config; |
495 | // Prevent the pattern driver from merging blocks |
496 | config.enableRegionSimplification = false; |
497 | |
498 | if (mlir::failed(mlir::applyPatternsAndFoldGreedily( |
499 | module, std::move(patterns), config))) { |
500 | mlir::emitError(mlir::UnknownLoc::get(context), |
501 | "failure in HLFIR intrinsic lowering" ); |
502 | signalPassFailure(); |
503 | } |
504 | } |
505 | }; |
506 | } // namespace |
507 | |
508 | std::unique_ptr<mlir::Pass> hlfir::createLowerHLFIRIntrinsicsPass() { |
509 | return std::make_unique<LowerHLFIRIntrinsics>(); |
510 | } |
511 | |