aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorValentin Clement (バレンタイン クレメン) <clementval@gmail.com>2024-01-22 08:40:52 -0800
committerGitHub <noreply@github.com>2024-01-22 08:40:52 -0800
commitee6199ca3cf101c764788ebf8df5b0e3e00f5538 (patch)
treeccc41bd23215b28d690b3b40ac9e0197028ffcc6 /mlir
parentb5df6a90f5365e61d2dfa1583d36cbc79ab5775b (diff)
downloadllvm-ee6199ca3cf101c764788ebf8df5b0e3e00f5538.zip
llvm-ee6199ca3cf101c764788ebf8df5b0e3e00f5538.tar.gz
llvm-ee6199ca3cf101c764788ebf8df5b0e3e00f5538.tar.bz2
[mlir][openacc][NFC] Cleanup hasOnly functions for device_type support (#78800)
Just a cleanup for all the `has.*Only()` function to avoid code duplication
Diffstat (limited to 'mlir')
-rw-r--r--mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp150
1 files changed, 49 insertions, 101 deletions
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 80f0529..bdc9c34 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -70,6 +70,41 @@ void OpenACCDialect::initialize() {
}
//===----------------------------------------------------------------------===//
+// device_type support helpers
+//===----------------------------------------------------------------------===//
+
+static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
+ if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
+ return true;
+ return false;
+}
+
+static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
+ mlir::acc::DeviceType deviceType) {
+ if (!hasDeviceTypeValues(arrayAttr))
+ return false;
+
+ for (auto attr : *arrayAttr) {
+ auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
+ if (deviceTypeAttr.getValue() == deviceType)
+ return true;
+ }
+
+ return false;
+}
+
+static void printDeviceTypes(mlir::OpAsmPrinter &p,
+ std::optional<mlir::ArrayAttr> deviceTypes) {
+ if (!hasDeviceTypeValues(deviceTypes))
+ return;
+
+ p << "[";
+ llvm::interleaveComma(*deviceTypes, p,
+ [&](mlir::Attribute attr) { p << attr; });
+ p << "]";
+}
+
+//===----------------------------------------------------------------------===//
// DataBoundsOp
//===----------------------------------------------------------------------===//
LogicalResult acc::DataBoundsOp::verify() {
@@ -722,11 +757,7 @@ bool acc::ParallelOp::hasAsyncOnly() {
}
bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getAsyncOnly()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
+ return hasDeviceType(getAsyncOnly(), deviceType);
}
mlir::Value acc::ParallelOp::getAsyncValue() {
@@ -789,11 +820,7 @@ bool acc::ParallelOp::hasWaitOnly() {
}
bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getWaitOnly()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
+ return hasDeviceType(getWaitOnly(), deviceType);
}
mlir::Operation::operand_range ParallelOp::getWaitValues() {
@@ -1033,23 +1060,6 @@ static ParseResult parseDeviceTypeOperandsWithKeywordOnly(
return success();
}
-static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
- if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
- return true;
- return false;
-}
-
-static void printDeviceTypes(mlir::OpAsmPrinter &p,
- std::optional<mlir::ArrayAttr> deviceTypes) {
- if (!hasDeviceTypeValues(deviceTypes))
- return;
-
- p << "[";
- llvm::interleaveComma(*deviceTypes, p,
- [&](mlir::Attribute attr) { p << attr; });
- p << "]";
-}
-
static void printDeviceTypeOperandsWithKeywordOnly(
mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands,
mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
@@ -1093,11 +1103,7 @@ bool acc::SerialOp::hasAsyncOnly() {
}
bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getAsyncOnly()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
+ return hasDeviceType(getAsyncOnly(), deviceType);
}
mlir::Value acc::SerialOp::getAsyncValue() {
@@ -1114,11 +1120,7 @@ bool acc::SerialOp::hasWaitOnly() {
}
bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getWaitOnly()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
+ return hasDeviceType(getWaitOnly(), deviceType);
}
mlir::Operation::operand_range SerialOp::getWaitValues() {
@@ -1177,11 +1179,7 @@ bool acc::KernelsOp::hasAsyncOnly() {
}
bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getAsyncOnly()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
+ return hasDeviceType(getAsyncOnly(), deviceType);
}
mlir::Value acc::KernelsOp::getAsyncValue() {
@@ -1228,11 +1226,7 @@ bool acc::KernelsOp::hasWaitOnly() {
}
bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getWaitOnly()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
+ return hasDeviceType(getWaitOnly(), deviceType);
}
mlir::Operation::operand_range KernelsOp::getWaitValues() {
@@ -1646,11 +1640,7 @@ Value LoopOp::getDataOperand(unsigned i) {
bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getAuto_()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
+ return hasDeviceType(getAuto_(), deviceType);
}
bool LoopOp::hasIndependent() {
@@ -1658,21 +1648,13 @@ bool LoopOp::hasIndependent() {
}
bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getIndependent()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
+ return hasDeviceType(getIndependent(), deviceType);
}
bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getSeq()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
+ return hasDeviceType(getSeq(), deviceType);
}
mlir::Value LoopOp::getVectorValue() {
@@ -1687,11 +1669,7 @@ mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getVector()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
+ return hasDeviceType(getVector(), deviceType);
}
mlir::Value LoopOp::getWorkerValue() {
@@ -1706,11 +1684,7 @@ mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getWorker()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
+ return hasDeviceType(getWorker(), deviceType);
}
mlir::Operation::operand_range LoopOp::getTileValues() {
@@ -1771,11 +1745,7 @@ mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getGang()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
+ return hasDeviceType(getGang(), deviceType);
}
//===----------------------------------------------------------------------===//
@@ -1815,11 +1785,7 @@ bool acc::DataOp::hasAsyncOnly() {
}
bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getAsyncOnly()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
+ return hasDeviceType(getAsyncOnly(), deviceType);
}
mlir::Value DataOp::getAsyncValue() {
@@ -1834,11 +1800,7 @@ mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getWaitOnly()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
+ return hasDeviceType(getWaitOnly(), deviceType);
}
mlir::Operation::operand_range DataOp::getWaitValues() {
@@ -2091,20 +2053,6 @@ LogicalResult acc::DeclareOp::verify() {
// RoutineOp
//===----------------------------------------------------------------------===//
-static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
- mlir::acc::DeviceType deviceType) {
- if (!hasDeviceTypeValues(arrayAttr))
- return false;
-
- for (auto attr : *arrayAttr) {
- auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
- if (deviceTypeAttr.getValue() == deviceType)
- return true;
- }
-
- return false;
-}
-
static unsigned getParallelismForDeviceType(acc::RoutineOp op,
acc::DeviceType dtype) {
unsigned parallelism = 0;