aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/Conversion/Passes.h1
-rw-r--r--mlir/include/mlir/Conversion/Passes.td11
-rw-r--r--mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h23
-rw-r--r--mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td9
-rw-r--r--mlir/lib/Conversion/CMakeLists.txt1
-rw-r--r--mlir/lib/Conversion/PassDetail.h4
-rw-r--r--mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt17
-rw-r--r--mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp113
-rw-r--r--mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp28
-rw-r--r--mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir65
10 files changed, 272 insertions, 0 deletions
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index b1d2da9..0d2281f 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -23,6 +23,7 @@
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
+#include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h"
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 7c61fad..fdf01b7 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -231,6 +231,17 @@ def ConvertPDLToPDLInterp : Pass<"convert-pdl-to-pdl-interp", "ModuleOp"> {
}
//===----------------------------------------------------------------------===//
+// SCFToOpenMP
+//===----------------------------------------------------------------------===//
+
+def ConvertSCFToOpenMP : FunctionPass<"convert-scf-to-openmp"> {
+ let summary = "Convert SCF parallel loop to OpenMP parallel + workshare "
+ "constructs.";
+ let constructor = "mlir::createConvertSCFToOpenMPPass()";
+ let dependentDialects = ["omp::OpenMPDialect"];
+}
+
+//===----------------------------------------------------------------------===//
// SCFToStandard
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h b/mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h
new file mode 100644
index 0000000..349c4e1
--- /dev/null
+++ b/mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h
@@ -0,0 +1,23 @@
+//===- ConvertSCFToOpenMP.h - SCF to OpenMP pass entrypoint -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_SCFTOOPENMP_SCFTOOPENMP_H
+#define MLIR_CONVERSION_SCFTOOPENMP_SCFTOOPENMP_H
+
+#include <memory>
+
+namespace mlir {
+class FuncOp;
+template <typename T>
+class OperationPass;
+
+std::unique_ptr<OperationPass<FuncOp>> createConvertSCFToOpenMPPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_SCFTOOPENMP_SCFTOOPENMP_H
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index d42466a..f915afc 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -92,6 +92,9 @@ def ParallelOp : OpenMP_Op<"parallel", [AttrSizedOperandSegments]> {
let regions = (region AnyRegion:$region);
+ let builders = [
+ OpBuilderDAG<(ins CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
+ ];
let parser = [{ return parseParallelOp(parser, result); }];
let printer = [{ return printParallelOp(p, *this); }];
let verifier = [{ return ::verifyParallelOp(*this); }];
@@ -175,6 +178,12 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments]> {
Confined<OptionalAttr<I64Attr>, [IntMinValue<0>]>:$ordered_val,
OptionalAttr<OrderKind>:$order_val);
+ let builders = [
+ OpBuilderDAG<(ins "ValueRange":$lowerBound, "ValueRange":$upperBound,
+ "ValueRange":$step,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
+ ];
+
let regions = (region AnyRegion:$region);
}
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index fa402cc..bf17895 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -12,6 +12,7 @@ add_subdirectory(LinalgToStandard)
add_subdirectory(OpenMPToLLVM)
add_subdirectory(PDLToPDLInterp)
add_subdirectory(SCFToGPU)
+add_subdirectory(SCFToOpenMP)
add_subdirectory(SCFToSPIRV)
add_subdirectory(SCFToStandard)
add_subdirectory(ShapeToStandard)
diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
index bee2f57..6314a5c 100644
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -33,6 +33,10 @@ namespace NVVM {
class NVVMDialect;
} // end namespace NVVM
+namespace omp {
+class OpenMPDialect;
+} // end namespace omp
+
namespace pdl_interp {
class PDLInterpDialect;
} // end namespace pdl_interp
diff --git a/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt b/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt
new file mode 100644
index 0000000..1ef4b74
--- /dev/null
+++ b/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_conversion_library(MLIRSCFToOpenMP
+ SCFToOpenMP.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToStandard
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIROpenMP
+ MLIRSCF
+ MLIRTransforms
+ )
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
new file mode 100644
index 0000000..01e7623
--- /dev/null
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -0,0 +1,113 @@
+//===- SCFToOpenMP.cpp - Structured Control Flow to OpenMP conversion -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert scf.parallel operations into OpenMP
+// parallel loops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
+#include "../PassDetail.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Converts SCF parallel operation into an OpenMP workshare loop construct.
+struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
+ using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
+ PatternRewriter &rewriter) const override {
+ // TODO: add support for reductions when OpenMP loops have them.
+ if (parallelOp.getNumResults() != 0)
+ return rewriter.notifyMatchFailure(
+ parallelOp,
+ "OpenMP dialect does not yet support loops with reductions");
+
+ // Replace SCF yield with OpenMP yield.
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToEnd(parallelOp.getBody());
+ assert(llvm::hasSingleElement(parallelOp.region()) &&
+ "expected scf.parallel to have one block");
+ rewriter.replaceOpWithNewOp<omp::YieldOp>(
+ parallelOp.getBody()->getTerminator(), ValueRange());
+ }
+
+ // Replace the loop.
+ auto loop = rewriter.create<omp::WsLoopOp>(
+ parallelOp.getLoc(), parallelOp.lowerBound(), parallelOp.upperBound(),
+ parallelOp.step());
+ rewriter.inlineRegionBefore(parallelOp.region(), loop.region(),
+ loop.region().begin());
+ rewriter.eraseOp(parallelOp);
+ return success();
+ }
+};
+
+/// Inserts OpenMP "parallel" operations around top-level SCF "parallel"
+/// operations in the given function. This is implemented as a direct IR
+/// modification rather than as a conversion pattern because it does not
+/// modify the top-level operation it matches, which is a requirement for
+/// rewrite patterns.
+//
+// TODO: consider creating nested parallel operations when necessary.
+static void insertOpenMPParallel(FuncOp func) {
+ // Collect top-level SCF "parallel" ops.
+ SmallVector<scf::ParallelOp, 4> topLevelParallelOps;
+ func.walk([&topLevelParallelOps](scf::ParallelOp parallelOp) {
+ // Ignore ops that are already within OpenMP parallel construct.
+ if (!parallelOp.getParentOfType<scf::ParallelOp>())
+ topLevelParallelOps.push_back(parallelOp);
+ });
+
+ // Wrap SCF ops into OpenMP "parallel" ops.
+ for (scf::ParallelOp parallelOp : topLevelParallelOps) {
+ OpBuilder builder(parallelOp);
+ auto omp = builder.create<omp::ParallelOp>(parallelOp.getLoc());
+ Block *block = builder.createBlock(&omp.getRegion());
+ builder.create<omp::TerminatorOp>(parallelOp.getLoc());
+ block->getOperations().splice(
+ block->begin(), parallelOp.getOperation()->getBlock()->getOperations(),
+ parallelOp.getOperation());
+ }
+}
+
+/// Applies the conversion patterns in the given function.
+static LogicalResult applyPatterns(FuncOp func) {
+ ConversionTarget target(*func.getContext());
+ target.addIllegalOp<scf::ParallelOp>();
+ target.addDynamicallyLegalOp<scf::YieldOp>(
+ [](scf::YieldOp op) { return !isa<scf::ParallelOp>(op.getParentOp()); });
+ target.addLegalDialect<omp::OpenMPDialect>();
+
+ OwningRewritePatternList patterns;
+ patterns.insert<ParallelOpLowering>(func.getContext());
+ FrozenRewritePatternList frozen(std::move(patterns));
+ return applyPartialConversion(func, target, frozen);
+}
+
+/// A pass converting SCF operations to OpenMP operations.
+struct SCFToOpenMPPass : public ConvertSCFToOpenMPBase<SCFToOpenMPPass> {
+ /// Pass entry point.
+ void runOnFunction() override {
+ insertOpenMPParallel(getFunction());
+ if (failed(applyPatterns(getFunction())))
+ signalPassFailure();
+ }
+};
+
+} // end namespace
+
+std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertSCFToOpenMPPass() {
+ return std::make_unique<SCFToOpenMPPass>();
+}
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 7ab4534..f4b76b6 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -37,6 +37,17 @@ void OpenMPDialect::initialize() {
// ParallelOp
//===----------------------------------------------------------------------===//
+void ParallelOp::build(OpBuilder &builder, OperationState &state,
+ ArrayRef<NamedAttribute> attributes) {
+ ParallelOp::build(
+ builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
+ /*default_val=*/nullptr, /*private_vars=*/ValueRange(),
+ /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(),
+ /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(),
+ /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr);
+ state.addAttributes(attributes);
+}
+
/// Parse a list of operands with types.
///
/// operand-and-type-list ::= `(` ssa-id-and-type-list `)`
@@ -362,5 +373,22 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
return success();
}
+//===----------------------------------------------------------------------===//
+// WsLoopOp
+//===----------------------------------------------------------------------===//
+
+void WsLoopOp::build(OpBuilder &builder, OperationState &state,
+ ValueRange lowerBound, ValueRange upperBound,
+ ValueRange step, ArrayRef<NamedAttribute> attributes) {
+ build(builder, state, TypeRange(), lowerBound, upperBound, step,
+ /*private_vars=*/ValueRange(),
+ /*firstprivate_vars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(),
+ /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
+ /*schedule_val=*/nullptr, /*schedule_chunk_var=*/nullptr,
+ /*collapse_val=*/nullptr,
+ /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr);
+ state.addAttributes(attributes);
+}
+
#define GET_OP_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
diff --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
new file mode 100644
index 0000000..466bd6a
--- /dev/null
+++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt -convert-scf-to-openmp %s | FileCheck %s
+
+// CHECK-LABEL: @parallel
+func @parallel(%arg0: index, %arg1: index, %arg2: index,
+ %arg3: index, %arg4: index, %arg5: index) {
+ // CHECK: omp.parallel {
+ // CHECK: "omp.wsloop"({{.*}}) ( {
+ scf.parallel (%i, %j) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
+ // CHECK: test.payload
+ "test.payload"(%i, %j) : (index, index) -> ()
+ // CHECK: omp.yield
+ // CHECK: }
+ }
+ // CHECK: omp.terminator
+ // CHECK: }
+ return
+}
+
+// CHECK-LABEL: @nested_loops
+func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
+ %arg3: index, %arg4: index, %arg5: index) {
+ // CHECK: omp.parallel {
+ // CHECK: "omp.wsloop"({{.*}}) ( {
+ // CHECK-NOT: omp.parallel
+ scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
+ // CHECK: "omp.wsloop"({{.*}}) ( {
+ scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {
+ // CHECK: test.payload
+ "test.payload"(%i, %j) : (index, index) -> ()
+ // CHECK: omp.yield
+ // CHECK: }
+ }
+ // CHECK: omp.yield
+ // CHECK: }
+ }
+ // CHECK: omp.terminator
+ // CHECK: }
+ return
+}
+
+func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
+ %arg3: index, %arg4: index, %arg5: index) {
+ // CHECK: omp.parallel {
+ // CHECK: "omp.wsloop"({{.*}}) ( {
+ scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
+ // CHECK: test.payload1
+ "test.payload1"(%i) : (index) -> ()
+ // CHECK: omp.yield
+ // CHECK: }
+ }
+ // CHECK: omp.terminator
+ // CHECK: }
+
+ // CHECK: omp.parallel {
+ // CHECK: "omp.wsloop"({{.*}}) ( {
+ scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {
+ // CHECK: test.payload2
+ "test.payload2"(%j) : (index) -> ()
+ // CHECK: omp.yield
+ // CHECK: }
+ }
+ // CHECK: omp.terminator
+ // CHECK: }
+ return
+}