blob: 5aa1273a1be3684604d462eaff905c073527f978 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
|
//===- LowerNontemporal.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Add the nontemporal attribute to loads and stores of variables listed in
// the nontemporal clause of an OpenMP simd construct.
//
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/Dialect/FIRCG/CGOps.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/OpenMP/Passes.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
namespace flangomp {
#define GEN_PASS_DEF_LOWERNONTEMPORALPASS
#include "flang/Optimizer/OpenMP/Passes.h.inc"
} // namespace flangomp
namespace {
class LowerNontemporalPass
: public flangomp::impl::LowerNontemporalPassBase<LowerNontemporalPass> {
void addNonTemporalAttr(omp::SimdOp simdOp) {
if (simdOp.getNontemporalVars().empty())
return;
std::function<mlir::Value(mlir::Value)> getBaseOperand =
[&](mlir::Value operand) -> mlir::Value {
auto *defOp = operand.getDefiningOp();
while (defOp) {
llvm::TypeSwitch<Operation *>(defOp)
.Case<fir::ArrayCoorOp, fir::cg::XArrayCoorOp, fir::LoadOp>(
[&](auto op) {
operand = op.getMemref();
defOp = operand.getDefiningOp();
})
.Case<fir::BoxAddrOp>([&](auto op) {
operand = op.getVal();
defOp = operand.getDefiningOp();
})
.Default([&](auto op) { defOp = nullptr; });
}
return operand;
};
// walk through the operations and mark the load and store as nontemporal
simdOp->walk([&](Operation *op) {
mlir::Value operand = nullptr;
if (auto loadOp = llvm::dyn_cast<fir::LoadOp>(op))
operand = loadOp.getMemref();
else if (auto storeOp = llvm::dyn_cast<fir::StoreOp>(op))
operand = storeOp.getMemref();
// Skip load and store operations involving boxes (allocatable or pointer
// types).
if (operand && !(fir::isAllocatableType(operand.getType()) ||
fir::isPointerType((operand.getType())))) {
operand = getBaseOperand(operand);
// TODO : Handling of nontemporal clause inside atomic construct
if (llvm::is_contained(simdOp.getNontemporalVars(), operand)) {
if (auto loadOp = llvm::dyn_cast<fir::LoadOp>(op))
loadOp.setNontemporal(true);
else if (auto storeOp = llvm::dyn_cast<fir::StoreOp>(op))
storeOp.setNontemporal(true);
}
}
});
}
void runOnOperation() override {
Operation *op = getOperation();
op->walk([&](omp::SimdOp simdOp) { addNonTemporalAttr(simdOp); });
}
};
} // namespace
|