diff options
author | Valentin Clement (バレンタイン クレメン) <clementval@gmail.com> | 2024-01-22 08:40:52 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-22 08:40:52 -0800 |
commit | ee6199ca3cf101c764788ebf8df5b0e3e00f5538 (patch) | |
tree | ccc41bd23215b28d690b3b40ac9e0197028ffcc6 | |
parent | b5df6a90f5365e61d2dfa1583d36cbc79ab5775b (diff) | |
download | llvm-ee6199ca3cf101c764788ebf8df5b0e3e00f5538.zip llvm-ee6199ca3cf101c764788ebf8df5b0e3e00f5538.tar.gz llvm-ee6199ca3cf101c764788ebf8df5b0e3e00f5538.tar.bz2 |
[mlir][openacc][NFC] Cleanup hasOnly functions for device_type support (#78800)
Just a cleanup for all the `has.*Only()` function to avoid code
duplication
-rw-r--r-- | mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 150 |
1 files changed, 49 insertions, 101 deletions
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 80f0529..bdc9c34 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -70,6 +70,41 @@ void OpenACCDialect::initialize() { } //===----------------------------------------------------------------------===// +// device_type support helpers +//===----------------------------------------------------------------------===// + +static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) { + if (arrayAttr && *arrayAttr && arrayAttr->size() > 0) + return true; + return false; +} + +static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr, + mlir::acc::DeviceType deviceType) { + if (!hasDeviceTypeValues(arrayAttr)) + return false; + + for (auto attr : *arrayAttr) { + auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr); + if (deviceTypeAttr.getValue() == deviceType) + return true; + } + + return false; +} + +static void printDeviceTypes(mlir::OpAsmPrinter &p, + std::optional<mlir::ArrayAttr> deviceTypes) { + if (!hasDeviceTypeValues(deviceTypes)) + return; + + p << "["; + llvm::interleaveComma(*deviceTypes, p, + [&](mlir::Attribute attr) { p << attr; }); + p << "]"; +} + +//===----------------------------------------------------------------------===// // DataBoundsOp //===----------------------------------------------------------------------===// LogicalResult acc::DataBoundsOp::verify() { @@ -722,11 +757,7 @@ bool acc::ParallelOp::hasAsyncOnly() { } bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) { - if (auto arrayAttr = getAsyncOnly()) { - if (findSegment(*arrayAttr, deviceType)) - return true; - } - return false; + return hasDeviceType(getAsyncOnly(), deviceType); } mlir::Value acc::ParallelOp::getAsyncValue() { @@ -789,11 +820,7 @@ bool acc::ParallelOp::hasWaitOnly() { } bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) { - if (auto arrayAttr = getWaitOnly()) { - if (findSegment(*arrayAttr, deviceType)) - return true; - } - return false; + return hasDeviceType(getWaitOnly(), deviceType); } mlir::Operation::operand_range ParallelOp::getWaitValues() { @@ -1033,23 +1060,6 @@ static ParseResult parseDeviceTypeOperandsWithKeywordOnly( return success(); } -static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) { - if (arrayAttr && *arrayAttr && arrayAttr->size() > 0) - return true; - return false; -} - -static void printDeviceTypes(mlir::OpAsmPrinter &p, - std::optional<mlir::ArrayAttr> deviceTypes) { - if (!hasDeviceTypeValues(deviceTypes)) - return; - - p << "["; - llvm::interleaveComma(*deviceTypes, p, - [&](mlir::Attribute attr) { p << attr; }); - p << "]"; -} - static void printDeviceTypeOperandsWithKeywordOnly( mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes, @@ -1093,11 +1103,7 @@ bool acc::SerialOp::hasAsyncOnly() { } bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) { - if (auto arrayAttr = getAsyncOnly()) { - if (findSegment(*arrayAttr, deviceType)) - return true; - } - return false; + return hasDeviceType(getAsyncOnly(), deviceType); } mlir::Value acc::SerialOp::getAsyncValue() { @@ -1114,11 +1120,7 @@ bool acc::SerialOp::hasWaitOnly() { } bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) { - if (auto arrayAttr = getWaitOnly()) { - if (findSegment(*arrayAttr, deviceType)) - return true; - } - return false; + return hasDeviceType(getWaitOnly(), deviceType); } mlir::Operation::operand_range SerialOp::getWaitValues() { @@ -1177,11 +1179,7 @@ bool acc::KernelsOp::hasAsyncOnly() { } bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) { - if (auto arrayAttr = getAsyncOnly()) { - if (findSegment(*arrayAttr, deviceType)) - return true; - } - return false; + return hasDeviceType(getAsyncOnly(), deviceType); } mlir::Value acc::KernelsOp::getAsyncValue() { @@ -1228,11 +1226,7 @@ bool acc::KernelsOp::hasWaitOnly() { } bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) { - if (auto arrayAttr = getWaitOnly()) { - if (findSegment(*arrayAttr, deviceType)) - return true; - } - return false; + return hasDeviceType(getWaitOnly(), deviceType); } mlir::Operation::operand_range KernelsOp::getWaitValues() { @@ -1646,11 +1640,7 @@ Value LoopOp::getDataOperand(unsigned i) { bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); } bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) { - if (auto arrayAttr = getAuto_()) { - if (findSegment(*arrayAttr, deviceType)) - return true; - } - return false; + return hasDeviceType(getAuto_(), deviceType); } bool LoopOp::hasIndependent() { @@ -1658,21 +1648,13 @@ bool LoopOp::hasIndependent() { } bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) { - if (auto arrayAttr = getIndependent()) { - if (findSegment(*arrayAttr, deviceType)) - return true; - } - return false; + return hasDeviceType(getIndependent(), deviceType); } bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); } bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) { - if (auto arrayAttr = getSeq()) { - if (findSegment(*arrayAttr, deviceType)) - return true; - } - return false; + return hasDeviceType(getSeq(), deviceType); } mlir::Value LoopOp::getVectorValue() { @@ -1687,11 +1669,7 @@ mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) { bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); } bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) { - if (auto arrayAttr = getVector()) { - if (findSegment(*arrayAttr, deviceType)) - return true; - } - return false; + return hasDeviceType(getVector(), deviceType); } mlir::Value LoopOp::getWorkerValue() { @@ -1706,11 +1684,7 @@ mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) { bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); } bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) { - if (auto arrayAttr = getWorker()) { - if (findSegment(*arrayAttr, deviceType)) - return true; - } - return false; + return hasDeviceType(getWorker(), deviceType); } mlir::Operation::operand_range LoopOp::getTileValues() { @@ -1771,11 +1745,7 @@ mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType, bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); } bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) { - if (auto arrayAttr = getGang()) { - if (findSegment(*arrayAttr, deviceType)) - return true; - } - return false; + return hasDeviceType(getGang(), deviceType); } //===----------------------------------------------------------------------===// @@ -1815,11 +1785,7 @@ bool acc::DataOp::hasAsyncOnly() { } bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) { - if (auto arrayAttr = getAsyncOnly()) { - if (findSegment(*arrayAttr, deviceType)) - return true; - } - return false; + return hasDeviceType(getAsyncOnly(), deviceType); } mlir::Value DataOp::getAsyncValue() { @@ -1834,11 +1800,7 @@ mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) { bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); } bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) { - if (auto arrayAttr = getWaitOnly()) { - if (findSegment(*arrayAttr, deviceType)) - return true; - } - return false; + return hasDeviceType(getWaitOnly(), deviceType); } mlir::Operation::operand_range DataOp::getWaitValues() { @@ -2091,20 +2053,6 @@ LogicalResult acc::DeclareOp::verify() { // RoutineOp //===----------------------------------------------------------------------===// -static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr, - mlir::acc::DeviceType deviceType) { - if (!hasDeviceTypeValues(arrayAttr)) - return false; - - for (auto attr : *arrayAttr) { - auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr); - if (deviceTypeAttr.getValue() == deviceType) - return true; - } - - return false; -} - static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype) { unsigned parallelism = 0; |