1 | //===- NVVMAttachTarget.cpp - Attach an NVVM target -----------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements the `GpuNVVMAttachTarget` pass, attaching `#nvvm.target` |
10 | // attributes to GPU modules. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/GPU/Transforms/Passes.h" |
15 | |
16 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
17 | #include "mlir/Dialect/LLVMIR/NVVMDialect.h" |
18 | #include "mlir/IR/Builders.h" |
19 | #include "mlir/Pass/Pass.h" |
20 | #include "mlir/Target/LLVM/NVVM/Target.h" |
21 | #include "llvm/Support/Regex.h" |
22 | |
23 | namespace mlir { |
24 | #define GEN_PASS_DEF_GPUNVVMATTACHTARGET |
25 | #include "mlir/Dialect/GPU/Transforms/Passes.h.inc" |
26 | } // namespace mlir |
27 | |
28 | using namespace mlir; |
29 | using namespace mlir::NVVM; |
30 | |
31 | namespace { |
32 | struct NVVMAttachTarget |
33 | : public impl::GpuNVVMAttachTargetBase<NVVMAttachTarget> { |
34 | using Base::Base; |
35 | |
36 | DictionaryAttr getFlags(OpBuilder &builder) const; |
37 | |
38 | void runOnOperation() override; |
39 | |
40 | void getDependentDialects(DialectRegistry ®istry) const override { |
41 | registry.insert<NVVM::NVVMDialect>(); |
42 | } |
43 | }; |
44 | } // namespace |
45 | |
46 | DictionaryAttr NVVMAttachTarget::getFlags(OpBuilder &builder) const { |
47 | UnitAttr unitAttr = builder.getUnitAttr(); |
48 | SmallVector<NamedAttribute, 2> flags; |
49 | auto addFlag = [&](StringRef flag) { |
50 | flags.push_back(Elt: builder.getNamedAttr(name: flag, val: unitAttr)); |
51 | }; |
52 | if (fastFlag) |
53 | addFlag("fast" ); |
54 | if (ftzFlag) |
55 | addFlag("ftz" ); |
56 | if (!flags.empty()) |
57 | return builder.getDictionaryAttr(flags); |
58 | return nullptr; |
59 | } |
60 | |
61 | void NVVMAttachTarget::runOnOperation() { |
62 | OpBuilder builder(&getContext()); |
63 | ArrayRef<std::string> libs(linkLibs); |
64 | SmallVector<StringRef> filesToLink(libs.begin(), libs.end()); |
65 | auto target = builder.getAttr<NVVMTargetAttr>( |
66 | optLevel, triple, chip, features, getFlags(builder), |
67 | filesToLink.empty() ? nullptr : builder.getStrArrayAttr(filesToLink)); |
68 | llvm::Regex matcher(moduleMatcher); |
69 | for (Region ®ion : getOperation()->getRegions()) |
70 | for (Block &block : region.getBlocks()) |
71 | for (auto module : block.getOps<gpu::GPUModuleOp>()) { |
72 | // Check if the name of the module matches. |
73 | if (!moduleMatcher.empty() && !matcher.match(module.getName())) |
74 | continue; |
75 | // Create the target array. |
76 | SmallVector<Attribute> targets; |
77 | if (std::optional<ArrayAttr> attrs = module.getTargets()) |
78 | targets.append(attrs->getValue().begin(), attrs->getValue().end()); |
79 | targets.push_back(target); |
80 | // Remove any duplicate targets. |
81 | targets.erase(std::unique(targets.begin(), targets.end()), |
82 | targets.end()); |
83 | // Update the target attribute array. |
84 | module.setTargetsAttr(builder.getArrayAttr(targets)); |
85 | } |
86 | } |
87 | |