aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSergio Afonso <safonsof@amd.com>2024-04-19 16:15:10 +0100
committerGitHub <noreply@github.com>2024-04-19 16:15:10 +0100
commit5e5b8c49096afba8e4e0fd094a5ab905a9acced0 (patch)
tree10ba39e5ba4b4bf013a3354dbe48cb931779a93b
parent9dbf3e2384e450c2b4f282b85b9ec47c65976194 (diff)
downloadllvm-5e5b8c49096afba8e4e0fd094a5ab905a9acced0.zip
llvm-5e5b8c49096afba8e4e0fd094a5ab905a9acced0.tar.gz
llvm-5e5b8c49096afba8e4e0fd094a5ab905a9acced0.tar.bz2
[MLIR][OpenMP] Verify loop wrapper properties of omp.parallel (#88722)
This patch extends verification of the `omp.parallel` operation to check it is correctly defined when taking a loop wrapper role. In OpenMP, a PARALLEL construct can be either a (potenially combined) block construct or a loop construct, when appearing as part of a composite construct. This is currently the case for the DISTRIBUTE PARALLEL DO/FOR and DISTRIBUTE PARALLEL DO/FOR SIMD exclusively. When used to represent the PARALLEL leaf of a composite construct, it must follow the rules of a wrapper loop operation in MLIR, and this is what this patch ensures. No additional restrictions are introduced for PARALLEL block constructs.
-rw-r--r--mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp16
-rw-r--r--mlir/test/Dialect/OpenMP/invalid.mlir52
-rw-r--r--mlir/test/Dialect/OpenMP/ops.mlir20
3 files changed, 87 insertions, 1 deletions
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index f380926..528a0d0 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1344,6 +1344,22 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
}
LogicalResult ParallelOp::verify() {
+ // Check that it is a valid loop wrapper if it's taking that role.
+ if (isa<DistributeOp>((*this)->getParentOp())) {
+ if (!isWrapper())
+ return emitOpError() << "must take a loop wrapper role if nested inside "
+ "of 'omp.distribute'";
+
+ if (LoopWrapperInterface nested = getNestedWrapper()) {
+ // Check for the allowed leaf constructs that may appear in a composite
+ // construct directly after PARALLEL.
+ if (!isa<WsloopOp>(nested))
+ return emitError() << "only supported nested wrapper is 'omp.wsloop'";
+ } else {
+ return emitOpError() << "must not wrap an 'omp.loop_nest' directly";
+ }
+ }
+
if (getAllocateVars().size() != getAllocatorsVars().size())
return emitError(
"expected equal sizes for allocate and allocator variables");
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 1f04f45..2f24dce 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -10,6 +10,58 @@ func.func @unknown_clause() {
// -----
+func.func @not_wrapper() {
+ omp.distribute {
+ // expected-error@+1 {{op must take a loop wrapper role if nested inside of 'omp.distribute'}}
+ omp.parallel {
+ %0 = arith.constant 0 : i32
+ omp.terminator
+ }
+ omp.terminator
+ }
+
+ return
+}
+
+// -----
+
+func.func @invalid_nested_wrapper(%lb : index, %ub : index, %step : index) {
+ omp.distribute {
+ // expected-error@+1 {{only supported nested wrapper is 'omp.wsloop'}}
+ omp.parallel {
+ omp.simd {
+ omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+ omp.yield
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+
+ return
+}
+
+// -----
+
+func.func @no_nested_wrapper(%lb : index, %ub : index, %step : index) {
+ omp.distribute {
+ // expected-error@+1 {{op must not wrap an 'omp.loop_nest' directly}}
+ omp.parallel {
+ omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+ omp.yield
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+
+ return
+}
+
+// -----
+
func.func @if_once(%n : i1) {
// expected-error@+1 {{`if` clause can appear at most once in the expansion of the oilist directive}}
omp.parallel if(%n : i1) if(%n : i1) {
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index e2ca12a..c10fc88 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -51,7 +51,7 @@ func.func @omp_terminator() -> () {
omp.terminator
}
-func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i32) -> () {
+func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i32, %idx : index) -> () {
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
"omp.parallel" (%if_cond, %num_threads, %data_var, %data_var) ({
@@ -85,6 +85,24 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
omp.terminator
}) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (memref<i32>, memref<i32>) -> ()
+ // CHECK: omp.distribute
+ omp.distribute {
+ // CHECK-NEXT: omp.parallel
+ omp.parallel {
+ // CHECK-NEXT: omp.wsloop
+ // TODO Remove induction variables from omp.wsloop.
+ omp.wsloop for (%iv) : index = (%idx) to (%idx) step (%idx) {
+ // CHECK-NEXT: omp.loop_nest
+ omp.loop_nest (%iv2) : index = (%idx) to (%idx) step (%idx) {
+ omp.yield
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+
return
}