1//===- LowerNontemporal.cpp -------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Add the nontemporal attribute to loads and stores of variables listed in
// the nontemporal clause of an OpenMP simd construct.
11//
12//===----------------------------------------------------------------------===//
13
14#include "flang/Optimizer/Dialect/FIRCG/CGOps.h"
15#include "flang/Optimizer/Dialect/FIROpsSupport.h"
16#include "flang/Optimizer/OpenMP/Passes.h"
17#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
18#include "llvm/ADT/TypeSwitch.h"
19
20using namespace mlir;
21
22namespace flangomp {
23#define GEN_PASS_DEF_LOWERNONTEMPORALPASS
24#include "flang/Optimizer/OpenMP/Passes.h.inc"
25} // namespace flangomp
26
27namespace {
28class LowerNontemporalPass
29 : public flangomp::impl::LowerNontemporalPassBase<LowerNontemporalPass> {
30 void addNonTemporalAttr(omp::SimdOp simdOp) {
31 if (simdOp.getNontemporalVars().empty())
32 return;
33
34 std::function<mlir::Value(mlir::Value)> getBaseOperand =
35 [&](mlir::Value operand) -> mlir::Value {
36 auto *defOp = operand.getDefiningOp();
37 while (defOp) {
38 llvm::TypeSwitch<Operation *>(defOp)
39 .Case<fir::ArrayCoorOp, fir::cg::XArrayCoorOp, fir::LoadOp>(
40 [&](auto op) {
41 operand = op.getMemref();
42 defOp = operand.getDefiningOp();
43 })
44 .Case<fir::BoxAddrOp>([&](auto op) {
45 operand = op.getVal();
46 defOp = operand.getDefiningOp();
47 })
48 .Default([&](auto op) { defOp = nullptr; });
49 }
50 return operand;
51 };
52
53 // walk through the operations and mark the load and store as nontemporal
54 simdOp->walk([&](Operation *op) {
55 mlir::Value operand = nullptr;
56
57 if (auto loadOp = llvm::dyn_cast<fir::LoadOp>(op))
58 operand = loadOp.getMemref();
59 else if (auto storeOp = llvm::dyn_cast<fir::StoreOp>(op))
60 operand = storeOp.getMemref();
61
62 // Skip load and store operations involving boxes (allocatable or pointer
63 // types).
64 if (operand && !(fir::isAllocatableType(operand.getType()) ||
65 fir::isPointerType((operand.getType())))) {
66 operand = getBaseOperand(operand);
67
68 // TODO : Handling of nontemporal clause inside atomic construct
69 if (llvm::is_contained(simdOp.getNontemporalVars(), operand)) {
70 if (auto loadOp = llvm::dyn_cast<fir::LoadOp>(op))
71 loadOp.setNontemporal(true);
72 else if (auto storeOp = llvm::dyn_cast<fir::StoreOp>(op))
73 storeOp.setNontemporal(true);
74 }
75 }
76 });
77 }
78
79 void runOnOperation() override {
80 Operation *op = getOperation();
81 op->walk([&](omp::SimdOp simdOp) { addNonTemporalAttr(simdOp); });
82 }
83};
84} // namespace
85

// source code of flang/lib/Optimizer/OpenMP/LowerNontemporal.cpp