aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSergio Afonso <safonsof@amd.com>2025-08-13 12:53:34 +0100
committerSergio Afonso <safonsof@amd.com>2025-08-13 12:53:34 +0100
commit9e948a58af729d8c142c6e1c4a252a01fd2e6dbd (patch)
treeea7ea8db51de54431e7d0ec33c38c741928a9422
parent39800face19f2966a4456d2a4c583cd87a693c7e (diff)
downloadllvm-users/skatrak/flang-generic-01-mlir-pattern.zip
llvm-users/skatrak/flang-generic-01-mlir-pattern.tar.gz
llvm-users/skatrak/flang-generic-01-mlir-pattern.tar.bz2
Update TargetRegionFlags to mirror OMPTgtExecModeFlagsusers/skatrak/flang-generic-01-mlir-pattern
-rw-r--r--mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td20
-rw-r--r--mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td12
-rw-r--r--mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp51
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp22
4 files changed, 64 insertions, 41 deletions
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index ce0ebabd..deb2fba 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -223,19 +223,19 @@ def ScheduleModifier : OpenMP_I32EnumAttr<
def ScheduleModifierAttr : OpenMP_EnumAttr<ScheduleModifier, "sched_mod">;
//===----------------------------------------------------------------------===//
-// target_region_flags enum.
+// target_exec_mode enum.
//===----------------------------------------------------------------------===//
-def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">;
-def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 0>;
-def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 1>;
+def TargetExecModeBare : I32EnumAttrCase<"bare", 0>;
+def TargetExecModeGeneric : I32EnumAttrCase<"generic", 1>;
+def TargetExecModeSpmd : I32EnumAttrCase<"spmd", 2>;
-def TargetRegionFlags : OpenMP_BitEnumAttr<
- "TargetRegionFlags",
- "target region property flags", [
- TargetRegionFlagsNone,
- TargetRegionFlagsSpmd,
- TargetRegionFlagsTripCount
+def TargetExecMode : OpenMP_I32EnumAttr<
+ "TargetExecMode",
+ "target execution mode, mirroring the `OMPTgtExecModeFlags` LLVM enum", [
+ TargetExecModeBare,
+ TargetExecModeGeneric,
+ TargetExecModeSpmd,
]>;
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index be114ea..6569905 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1517,13 +1517,17 @@ def TargetOp : OpenMP_Op<"target", traits = [
/// operations, the top level one will be the one captured.
Operation *getInnermostCapturedOmpOp();
- /// Infers the kernel type (Generic, SPMD or Generic-SPMD) based on the
- /// contents of the target region.
+ /// Infers the kernel type (Bare, Generic or SPMD) based on the contents of
+ /// the target region.
///
/// \param capturedOp result of a still valid (no modifications made to any
/// nested operations) previous call to `getInnermostCapturedOmpOp()`.
- static ::mlir::omp::TargetRegionFlags
- getKernelExecFlags(Operation *capturedOp);
+ /// \param hostEvalTripCount output argument to store whether this kernel
+ /// wraps a loop whose bounds must be evaluated on the host prior to
+ /// launching it.
+ static ::mlir::omp::TargetExecMode
+ getKernelExecFlags(Operation *capturedOp,
+ bool *hostEvalTripCount = nullptr);
}] # clausesExtraClassDeclaration;
let assemblyFormat = clausesAssemblyFormat # [{
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 8854e90..c3c1700 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1974,8 +1974,9 @@ LogicalResult TargetOp::verifyRegions() {
return emitError("target containing multiple 'omp.teams' nested ops");
// Check that host_eval values are only used in legal ways.
+ bool hostEvalTripCount;
Operation *capturedOp = getInnermostCapturedOmpOp();
- TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
+ TargetExecMode execMode = getKernelExecFlags(capturedOp, &hostEvalTripCount);
for (Value hostEvalArg :
cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
for (Operation *user : hostEvalArg.getUsers()) {
@@ -1990,7 +1991,7 @@ LogicalResult TargetOp::verifyRegions() {
"and 'thread_limit' in 'omp.teams'";
}
if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
- if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
+ if (execMode == TargetExecMode::spmd &&
parallelOp->isAncestor(capturedOp) &&
hostEvalArg == parallelOp.getNumThreads())
continue;
@@ -2000,8 +2001,7 @@ LogicalResult TargetOp::verifyRegions() {
"'omp.parallel' when representing target SPMD";
}
if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
- if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
- loopNestOp.getOperation() == capturedOp &&
+ if (hostEvalTripCount && loopNestOp.getOperation() == capturedOp &&
(llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
@@ -2106,7 +2106,9 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
});
}
-TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
+TargetExecMode TargetOp::getKernelExecFlags(Operation *capturedOp,
+ bool *hostEvalTripCount) {
+ // TODO: Support detection of bare kernel mode.
// A non-null captured op is only valid if it resides inside of a TargetOp
// and is the result of calling getInnermostCapturedOmpOp() on it.
TargetOp targetOp =
@@ -2115,9 +2117,12 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
(targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
"unexpected captured op");
+ if (hostEvalTripCount)
+ *hostEvalTripCount = false;
+
// If it's not capturing a loop, it's a default target region.
if (!isa_and_present<LoopNestOp>(capturedOp))
- return TargetRegionFlags::none;
+ return TargetExecMode::generic;
// Get the innermost non-simd loop wrapper.
SmallVector<LoopWrapperInterface> loopWrappers;
@@ -2130,53 +2135,59 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
if (numWrappers != 1 && numWrappers != 2)
- return TargetRegionFlags::none;
+ return TargetExecMode::generic;
// Detect target-teams-distribute-parallel-wsloop[-simd].
if (numWrappers == 2) {
if (!isa<WsloopOp>(innermostWrapper))
- return TargetRegionFlags::none;
+ return TargetExecMode::generic;
innermostWrapper = std::next(innermostWrapper);
if (!isa<DistributeOp>(innermostWrapper))
- return TargetRegionFlags::none;
+ return TargetExecMode::generic;
Operation *parallelOp = (*innermostWrapper)->getParentOp();
if (!isa_and_present<ParallelOp>(parallelOp))
- return TargetRegionFlags::none;
+ return TargetExecMode::generic;
Operation *teamsOp = parallelOp->getParentOp();
if (!isa_and_present<TeamsOp>(teamsOp))
- return TargetRegionFlags::none;
+ return TargetExecMode::generic;
- if (teamsOp->getParentOp() == targetOp.getOperation())
- return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
+ if (teamsOp->getParentOp() == targetOp.getOperation()) {
+ if (hostEvalTripCount)
+ *hostEvalTripCount = true;
+ return TargetExecMode::spmd;
+ }
}
// Detect target-teams-distribute[-simd] and target-teams-loop.
else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
Operation *teamsOp = (*innermostWrapper)->getParentOp();
if (!isa_and_present<TeamsOp>(teamsOp))
- return TargetRegionFlags::none;
+ return TargetExecMode::generic;
if (teamsOp->getParentOp() != targetOp.getOperation())
- return TargetRegionFlags::none;
+ return TargetExecMode::generic;
+
+ if (hostEvalTripCount)
+ *hostEvalTripCount = true;
if (isa<LoopOp>(innermostWrapper))
- return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
+ return TargetExecMode::spmd;
- return TargetRegionFlags::trip_count;
+ return TargetExecMode::generic;
}
// Detect target-parallel-wsloop[-simd].
else if (isa<WsloopOp>(innermostWrapper)) {
Operation *parallelOp = (*innermostWrapper)->getParentOp();
if (!isa_and_present<ParallelOp>(parallelOp))
- return TargetRegionFlags::none;
+ return TargetExecMode::generic;
if (parallelOp->getParentOp() == targetOp.getOperation())
- return TargetRegionFlags::spmd;
+ return TargetExecMode::spmd;
}
- return TargetRegionFlags::none;
+ return TargetExecMode::generic;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 88601ef..d49cc38 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -5354,11 +5354,18 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
}
// Update kernel bounds structure for the `OpenMPIRBuilder` to use.
- omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
- attrs.ExecFlags =
- omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
- ? llvm::omp::OMP_TGT_EXEC_MODE_SPMD
- : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC;
+ omp::TargetExecMode execMode = targetOp.getKernelExecFlags(capturedOp);
+ switch (execMode) {
+ case omp::TargetExecMode::bare:
+ attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_BARE;
+ break;
+ case omp::TargetExecMode::generic:
+ attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_GENERIC;
+ break;
+ case omp::TargetExecMode::spmd:
+ attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
+ break;
+ }
attrs.MinTeams = minTeamsVal;
attrs.MaxTeams.front() = maxTeamsVal;
attrs.MinThreads = 1;
@@ -5408,8 +5415,9 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
if (numThreads)
attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
- if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
- omp::TargetRegionFlags::trip_count)) {
+ bool hostEvalTripCount;
+ targetOp.getKernelExecFlags(capturedOp, &hostEvalTripCount);
+ if (hostEvalTripCount) {
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
attrs.LoopTripCount = nullptr;