//===- ROCDLAttachTarget.cpp - Attach an ROCDL target ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the `GpuROCDLAttachTarget` pass, attaching
// `#rocdl.target` attributes to GPU modules.
//
//===----------------------------------------------------------------------===//
13 | |
#include "mlir/Dialect/GPU/Transforms/Passes.h"

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Target/LLVM/ROCDL/Target.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Regex.h"
22 | |
23 | namespace mlir { |
24 | #define GEN_PASS_DEF_GPUROCDLATTACHTARGET |
25 | #include "mlir/Dialect/GPU/Transforms/Passes.h.inc" |
26 | } // namespace mlir |
27 | |
28 | using namespace mlir; |
29 | using namespace mlir::ROCDL; |
30 | |
31 | namespace { |
32 | struct ROCDLAttachTarget |
33 | : public impl::GpuROCDLAttachTargetBase<ROCDLAttachTarget> { |
34 | using Base::Base; |
35 | |
36 | DictionaryAttr getFlags(OpBuilder &builder) const; |
37 | |
38 | void runOnOperation() override; |
39 | |
40 | void getDependentDialects(DialectRegistry ®istry) const override { |
41 | registry.insert<ROCDL::ROCDLDialect>(); |
42 | } |
43 | }; |
44 | } // namespace |
45 | |
46 | DictionaryAttr ROCDLAttachTarget::getFlags(OpBuilder &builder) const { |
47 | UnitAttr unitAttr = builder.getUnitAttr(); |
48 | SmallVector<NamedAttribute, 6> flags; |
49 | auto addFlag = [&](StringRef flag) { |
50 | flags.push_back(Elt: builder.getNamedAttr(name: flag, val: unitAttr)); |
51 | }; |
52 | if (!wave64Flag) |
53 | addFlag("no_wave64" ); |
54 | if (fastFlag) |
55 | addFlag("fast" ); |
56 | if (dazFlag) |
57 | addFlag("daz" ); |
58 | if (finiteOnlyFlag) |
59 | addFlag("finite_only" ); |
60 | if (unsafeMathFlag) |
61 | addFlag("unsafe_math" ); |
62 | if (!correctSqrtFlag) |
63 | addFlag("unsafe_sqrt" ); |
64 | if (!flags.empty()) |
65 | return builder.getDictionaryAttr(flags); |
66 | return nullptr; |
67 | } |
68 | |
69 | void ROCDLAttachTarget::runOnOperation() { |
70 | OpBuilder builder(&getContext()); |
71 | ArrayRef<std::string> libs(linkLibs); |
72 | SmallVector<StringRef> filesToLink(libs.begin(), libs.end()); |
73 | auto target = builder.getAttr<ROCDLTargetAttr>( |
74 | optLevel, triple, chip, features, abiVersion, getFlags(builder), |
75 | filesToLink.empty() ? nullptr : builder.getStrArrayAttr(filesToLink)); |
76 | llvm::Regex matcher(moduleMatcher); |
77 | for (Region ®ion : getOperation()->getRegions()) |
78 | for (Block &block : region.getBlocks()) |
79 | for (auto module : block.getOps<gpu::GPUModuleOp>()) { |
80 | // Check if the name of the module matches. |
81 | if (!moduleMatcher.empty() && !matcher.match(module.getName())) |
82 | continue; |
83 | // Create the target array. |
84 | SmallVector<Attribute> targets; |
85 | if (std::optional<ArrayAttr> attrs = module.getTargets()) |
86 | targets.append(attrs->getValue().begin(), attrs->getValue().end()); |
87 | targets.push_back(target); |
88 | // Remove any duplicate targets. |
89 | targets.erase(std::unique(targets.begin(), targets.end()), |
90 | targets.end()); |
91 | // Update the target attribute array. |
92 | module.setTargetsAttr(builder.getArrayAttr(targets)); |
93 | } |
94 | } |
95 | |