From 5e5b8c49096afba8e4e0fd094a5ab905a9acced0 Mon Sep 17 00:00:00 2001 From: Sergio Afonso Date: Fri, 19 Apr 2024 16:15:10 +0100 Subject: [MLIR][OpenMP] Verify loop wrapper properties of omp.parallel (#88722) This patch extends verification of the `omp.parallel` operation to check it is correctly defined when taking a loop wrapper role. In OpenMP, a PARALLEL construct can be either a (potenially combined) block construct or a loop construct, when appearing as part of a composite construct. This is currently the case for the DISTRIBUTE PARALLEL DO/FOR and DISTRIBUTE PARALLEL DO/FOR SIMD exclusively. When used to represent the PARALLEL leaf of a composite construct, it must follow the rules of a wrapper loop operation in MLIR, and this is what this patch ensures. No additional restrictions are introduced for PARALLEL block constructs. --- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 16 +++++++++ mlir/test/Dialect/OpenMP/invalid.mlir | 52 ++++++++++++++++++++++++++++ mlir/test/Dialect/OpenMP/ops.mlir | 20 ++++++++++- 3 files changed, 87 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index f380926..528a0d0 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1344,6 +1344,22 @@ static LogicalResult verifyPrivateVarList(OpType &op) { } LogicalResult ParallelOp::verify() { + // Check that it is a valid loop wrapper if it's taking that role. + if (isa((*this)->getParentOp())) { + if (!isWrapper()) + return emitOpError() << "must take a loop wrapper role if nested inside " + "of 'omp.distribute'"; + + if (LoopWrapperInterface nested = getNestedWrapper()) { + // Check for the allowed leaf constructs that may appear in a composite + // construct directly after PARALLEL. + if (!isa(nested)) + return emitError() << "only supported nested wrapper is 'omp.wsloop'"; + } else { + return emitOpError() << "must not wrap an 'omp.loop_nest' directly"; + } + } + if (getAllocateVars().size() != getAllocatorsVars().size()) return emitError( "expected equal sizes for allocate and allocator variables"); diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 1f04f45..2f24dce 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -10,6 +10,58 @@ func.func @unknown_clause() { // ----- +func.func @not_wrapper() { + omp.distribute { + // expected-error@+1 {{op must take a loop wrapper role if nested inside of 'omp.distribute'}} + omp.parallel { + %0 = arith.constant 0 : i32 + omp.terminator + } + omp.terminator + } + + return +} + +// ----- + +func.func @invalid_nested_wrapper(%lb : index, %ub : index, %step : index) { + omp.distribute { + // expected-error@+1 {{only supported nested wrapper is 'omp.wsloop'}} + omp.parallel { + omp.simd { + omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) { + omp.yield + } + omp.terminator + } + omp.terminator + } + omp.terminator + } + + return +} + +// ----- + +func.func @no_nested_wrapper(%lb : index, %ub : index, %step : index) { + omp.distribute { + // expected-error@+1 {{op must not wrap an 'omp.loop_nest' directly}} + omp.parallel { + omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) { + omp.yield + } + omp.terminator + } + omp.terminator + } + + return +} + +// ----- + func.func @if_once(%n : i1) { // expected-error@+1 {{`if` clause can appear at most once in the expansion of the oilist directive}} omp.parallel if(%n : i1) if(%n : i1) { diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index e2ca12a..c10fc88 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -51,7 +51,7 @@ func.func @omp_terminator() -> () { omp.terminator } -func.func @omp_parallel(%data_var : memref, %if_cond : i1, %num_threads : i32) -> () { +func.func @omp_parallel(%data_var : memref, %if_cond : i1, %num_threads : i32, %idx : index) -> () { // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32) allocate(%{{.*}} : memref -> %{{.*}} : memref) "omp.parallel" (%if_cond, %num_threads, %data_var, %data_var) ({ @@ -85,6 +85,24 @@ func.func @omp_parallel(%data_var : memref, %if_cond : i1, %num_threads : i omp.terminator }) {operandSegmentSizes = array} : (memref, memref) -> () + // CHECK: omp.distribute + omp.distribute { + // CHECK-NEXT: omp.parallel + omp.parallel { + // CHECK-NEXT: omp.wsloop + // TODO Remove induction variables from omp.wsloop. + omp.wsloop for (%iv) : index = (%idx) to (%idx) step (%idx) { + // CHECK-NEXT: omp.loop_nest + omp.loop_nest (%iv2) : index = (%idx) to (%idx) step (%idx) { + omp.yield + } + omp.terminator + } + omp.terminator + } + omp.terminator + } + return } -- cgit v1.1