diff options
author | Valentin Clement <clementval@gmail.com> | 2024-01-30 11:29:39 -0800 |
---|---|---|
committer | Valentin Clement <clementval@gmail.com> | 2024-02-05 12:34:38 -0800 |
commit | fa7d0d3e35f74486ccb0faa88ec706defe7dd2d2 (patch) | |
tree | 2fdb50e7fb8710834ffed54dcbe5708f85adf896 | |
parent | dd22140e21f2ef51cf031354966a3d41c191c6e7 (diff) | |
download | llvm-fa7d0d3e35f74486ccb0faa88ec706defe7dd2d2.zip llvm-fa7d0d3e35f74486ccb0faa88ec706defe7dd2d2.tar.gz llvm-fa7d0d3e35f74486ccb0faa88ec706defe7dd2d2.tar.bz2 |
[mlir][openacc] Add legalize data pass for compute operation (#80351)
This patch adds a simple pass to replace the uses inside compute operation. It
replaces the `varPtr` values with their corresponding `accPtr` values gathered
through the dataClauseOperands.
private and reductions variables are not included in this pass since they will
normally be replace when they are materialized.
-rw-r--r-- | flang/include/flang/Optimizer/Support/InitFIR.h | 2 | ||||
-rw-r--r-- | flang/test/Fir/OpenACC/legalize-data.fir | 24 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt | 2 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/OpenACC/Transforms/CMakeLists.txt | 5 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h | 40 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td | 28 | ||||
-rw-r--r-- | mlir/include/mlir/InitAllPasses.h | 2 | ||||
-rw-r--r-- | mlir/lib/Dialect/OpenACC/CMakeLists.txt | 22 | ||||
-rw-r--r-- | mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt | 20 | ||||
-rw-r--r-- | mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt | 20 | ||||
-rw-r--r-- | mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp | 72 | ||||
-rw-r--r-- | mlir/test/Dialect/OpenACC/legalize-data.mlir | 88 |
12 files changed, 305 insertions, 20 deletions
diff --git a/flang/include/flang/Optimizer/Support/InitFIR.h b/flang/include/flang/Optimizer/Support/InitFIR.h index 8c47ad3..b5c4169 100644 --- a/flang/include/flang/Optimizer/Support/InitFIR.h +++ b/flang/include/flang/Optimizer/Support/InitFIR.h @@ -19,6 +19,7 @@ #include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Func/Extensions/InlinerExtension.h" +#include "mlir/Dialect/OpenACC/Transforms/Passes.h" #include "mlir/InitAllDialects.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassRegistry.h" @@ -74,6 +75,7 @@ inline void loadDialects(mlir::MLIRContext &context) { /// Register the standard passes we use. This comes from registerAllPasses(), /// but is a smaller set since we aren't using many of the passes found there. inline void registerMLIRPassesForFortranTools() { + mlir::acc::registerOpenACCPasses(); mlir::registerCanonicalizerPass(); mlir::registerCSEPass(); mlir::affine::registerAffineLoopFusionPass(); diff --git a/flang/test/Fir/OpenACC/legalize-data.fir b/flang/test/Fir/OpenACC/legalize-data.fir new file mode 100644 index 0000000..3b86954 --- /dev/null +++ b/flang/test/Fir/OpenACC/legalize-data.fir @@ -0,0 +1,24 @@ +// RUN: fir-opt -split-input-file --openacc-legalize-data %s | FileCheck %s + +func.func @_QPsub1(%arg0: !fir.ref<i32> {fir.bindc_name = "i"}) { + %0:2 = hlfir.declare %arg0 {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>) + %1 = acc.copyin varPtr(%0#0 : !fir.ref<i32>) -> !fir.ref<i32> {dataClause = #acc<data_clause acc_copy>, name = "i"} + acc.parallel dataOperands(%1 : !fir.ref<i32>) { + %c0_i32 = arith.constant 0 : i32 + hlfir.assign %c0_i32 to %0#0 : i32, !fir.ref<i32> + acc.yield + } + acc.copyout accPtr(%1 : !fir.ref<i32>) to varPtr(%0#0 : !fir.ref<i32>) {dataClause = #acc<data_clause acc_copy>, name = "i"} + return +} + +// CHECK-LABEL: func.func @_QPsub1 +// CHECK-SAME: (%[[ARG0:.*]]: !fir.ref<i32> {fir.bindc_name = "i"}) +// CHECK: %[[I:.*]]:2 = hlfir.declare %[[ARG0]] {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>) +// CHECK: %[[COPYIN:.*]] = acc.copyin varPtr(%[[I]]#0 : !fir.ref<i32>) -> !fir.ref<i32> {dataClause = #acc<data_clause acc_copy>, name = "i"} +// CHECK: acc.parallel dataOperands(%[[COPYIN]] : !fir.ref<i32>) { +// CHECK: %c0_i32 = arith.constant 0 : i32 +// CHECK: hlfir.assign %c0{{.*}} to %[[COPYIN]] : i32, !fir.ref<i32> +// CHECK: acc.yield +// CHECK: } +// CHECK: acc.copyout accPtr(%[[COPYIN]] : !fir.ref<i32>) to varPtr(%[[I]]#0 : !fir.ref<i32>) {dataClause = #acc<data_clause acc_copy>, name = "i"} diff --git a/mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt index 56ba297..8a4b1c7 100644 --- a/mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt @@ -1,3 +1,5 @@ +add_subdirectory(Transforms) + set(LLVM_TARGET_DEFINITIONS ${LLVM_MAIN_INCLUDE_DIR}/llvm/Frontend/OpenACC/ACC.td) mlir_tablegen(AccCommon.td --gen-directive-decl --directives-dialect=OpenACC) add_public_tablegen_target(acc_common_td) diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenACC/Transforms/CMakeLists.txt new file mode 100644 index 0000000..ddbd583 --- /dev/null +++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name OpenACC) +add_public_tablegen_target(MLIROpenACCPassIncGen) + +add_mlir_doc(Passes OpenACCPasses ./ -gen-pass-doc) diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h new file mode 100644 index 0000000..5a11056 --- /dev/null +++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h @@ -0,0 +1,40 @@ +//===- Passes.h - OpenACC Passes Construction and Registration ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES_H +#define MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES_H + +#include "mlir/Dialect/LLVMIR/Transforms/AddComdats.h" +#include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h" +#include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h" +#include "mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h" +#include "mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h" +#include "mlir/Pass/Pass.h" + +#define GEN_PASS_DECL +#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc" + +namespace mlir { + +namespace func { +class FuncOp; +} // namespace func + +namespace acc { + +/// Create a pass to replace ssa values in region with device/host values. +std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeDataInRegion(); + +/// Generate the code for registering conversion passes. +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc" + +} // namespace acc +} // namespace mlir + +#endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES_H diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td new file mode 100644 index 0000000..abbc277 --- /dev/null +++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td @@ -0,0 +1,28 @@ +//===-- Passes.td - OpenACC pass definition file -----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES +#define MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES + +include "mlir/Pass/PassBase.td" + +def LegalizeDataInRegion : Pass<"openacc-legalize-data", "mlir::func::FuncOp"> { + let summary = "Legalize the data in the compute region"; + let description = [{ + This pass replace uses of varPtr in the compute region with their accPtr + gathered from the data clause operands. + }]; + let options = [ + Option<"hostToDevice", "host-to-device", "bool", "true", + "Replace varPtr uses with accPtr if true. Replace accPtr uses with " + "varPtr if false"> + ]; + let constructor = "::mlir::acc::createLegalizeDataInRegion()"; +} + +#endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h index 28dc3cc..e289216 100644 --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -34,6 +34,7 @@ #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/Mesh/Transforms/Passes.h" #include "mlir/Dialect/NVGPU/Transforms/Passes.h" +#include "mlir/Dialect/OpenACC/Transforms/Passes.h" #include "mlir/Dialect/SCF/Transforms/Passes.h" #include "mlir/Dialect/SPIRV/Transforms/Passes.h" #include "mlir/Dialect/Shape/Transforms/Passes.h" @@ -64,6 +65,7 @@ inline void registerAllPasses() { registerConversionPasses(); // Dialect passes + acc::registerOpenACCPasses(); affine::registerAffinePasses(); amdgpu::registerAMDGPUPasses(); registerAsyncPasses(); diff --git a/mlir/lib/Dialect/OpenACC/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/CMakeLists.txt index 2728524..9f57627 100644 --- a/mlir/lib/Dialect/OpenACC/CMakeLists.txt +++ b/mlir/lib/Dialect/OpenACC/CMakeLists.txt @@ -1,20 +1,2 @@ -add_mlir_dialect_library(MLIROpenACCDialect - IR/OpenACC.cpp - - ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC - - DEPENDS - MLIROpenACCOpsIncGen - MLIROpenACCEnumsIncGen - MLIROpenACCAttributesIncGen - MLIROpenACCOpsInterfacesIncGen - MLIROpenACCTypeInterfacesIncGen - - LINK_LIBS PUBLIC - MLIRIR - MLIRLLVMDialect - MLIRMemRefDialect - MLIROpenACCMPCommon - ) - +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt new file mode 100644 index 0000000..b802de1 --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt @@ -0,0 +1,20 @@ +add_mlir_dialect_library(MLIROpenACCDialect + OpenACC.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC + + DEPENDS + MLIROpenACCOpsIncGen + MLIROpenACCEnumsIncGen + MLIROpenACCAttributesIncGen + MLIROpenACCOpsInterfacesIncGen + MLIROpenACCTypeInterfacesIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMDialect + MLIRMemRefDialect + MLIROpenACCMPCommon + ) + diff --git a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt new file mode 100644 index 0000000..b7b9cf8 --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt @@ -0,0 +1,20 @@ +add_mlir_dialect_library(MLIROpenACCTransforms + LegalizeData.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC + + DEPENDS + MLIROpenACCPassIncGen + MLIROpenACCOpsIncGen + MLIROpenACCEnumsIncGen + MLIROpenACCAttributesIncGen + MLIROpenACCOpsInterfacesIncGen + MLIROpenACCTypeInterfacesIncGen + + LINK_LIBS PUBLIC + MLIROpenACCDialect + MLIRIR + MLIRPass + MLIRTransforms +) diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp new file mode 100644 index 0000000..ef44a0e --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp @@ -0,0 +1,72 @@ +//===- LegalizeData.cpp - -------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/OpenACC/Transforms/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/RegionUtils.h" + +namespace mlir { +namespace acc { +#define GEN_PASS_DEF_LEGALIZEDATAINREGION +#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc" +} // namespace acc +} // namespace mlir + +using namespace mlir; + +namespace { + +template <typename Op> +static void collectAndReplaceInRegion(Op &op, bool hostToDevice) { + llvm::SmallVector<std::pair<Value, Value>> values; + for (auto operand : op.getDataClauseOperands()) { + Value varPtr = acc::getVarPtr(operand.getDefiningOp()); + Value accPtr = acc::getAccPtr(operand.getDefiningOp()); + if (varPtr && accPtr) { + if (hostToDevice) + values.push_back({varPtr, accPtr}); + else + values.push_back({accPtr, varPtr}); + } + } + + for (auto p : values) + replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op.getRegion()); +} + +struct LegalizeDataInRegion + : public acc::impl::LegalizeDataInRegionBase<LegalizeDataInRegion> { + + void runOnOperation() override { + func::FuncOp funcOp = getOperation(); + bool replaceHostVsDevice = this->hostToDevice.getValue(); + + funcOp.walk([&](Operation *op) { + if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op)) + return; + + if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) { + collectAndReplaceInRegion(parallelOp, replaceHostVsDevice); + } else if (auto serialOp = dyn_cast<acc::SerialOp>(*op)) { + collectAndReplaceInRegion(serialOp, replaceHostVsDevice); + } else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) { + collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice); + } + }); + } +}; + +} // end anonymous namespace + +std::unique_ptr<OperationPass<func::FuncOp>> +mlir::acc::createLegalizeDataInRegion() { + return std::make_unique<LegalizeDataInRegion>(); +} diff --git a/mlir/test/Dialect/OpenACC/legalize-data.mlir b/mlir/test/Dialect/OpenACC/legalize-data.mlir new file mode 100644 index 0000000..f985741 --- /dev/null +++ b/mlir/test/Dialect/OpenACC/legalize-data.mlir @@ -0,0 +1,88 @@ +// RUN: mlir-opt -split-input-file --openacc-legalize-data %s | FileCheck %s --check-prefixes=CHECK,DEVICE +// RUN: mlir-opt -split-input-file --openacc-legalize-data=host-to-device=false %s | FileCheck %s --check-prefixes=CHECK,HOST + +func.func @test(%a: memref<10xf32>, %i : index) { + %create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32> + acc.parallel dataOperands(%create : memref<10xf32>) { + %ci = memref.load %a[%i] : memref<10xf32> + acc.yield + } + return +} + +// CHECK-LABEL: func.func @test +// CHECK-SAME: (%[[A:.*]]: memref<10xf32>, %[[I:.*]]: index) +// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32> +// CHECK: acc.parallel dataOperands(%[[CREATE]] : memref<10xf32>) { +// DEVICE: %{{.*}} = memref.load %[[CREATE]][%[[I]]] : memref<10xf32> +// HOST: %{{.*}} = memref.load %[[A]][%[[I]]] : memref<10xf32> +// CHECK: acc.yield +// CHECK: } + +// ----- + +func.func @test(%a: memref<10xf32>, %i : index) { + %create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32> + acc.serial dataOperands(%create : memref<10xf32>) { + %ci = memref.load %a[%i] : memref<10xf32> + acc.yield + } + return +} + +// CHECK-LABEL: func.func @test +// CHECK-SAME: (%[[A:.*]]: memref<10xf32>, %[[I:.*]]: index) +// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32> +// CHECK: acc.serial dataOperands(%[[CREATE]] : memref<10xf32>) { +// DEVICE: %{{.*}} = memref.load %[[CREATE]][%[[I]]] : memref<10xf32> +// HOST: %{{.*}} = memref.load %[[A]][%[[I]]] : memref<10xf32> +// CHECK: acc.yield +// CHECK: } + +// ----- + +func.func @test(%a: memref<10xf32>, %i : index) { + %create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32> + acc.kernels dataOperands(%create : memref<10xf32>) { + %ci = memref.load %a[%i] : memref<10xf32> + acc.terminator + } + return +} + +// CHECK-LABEL: func.func @test +// CHECK-SAME: (%[[A:.*]]: memref<10xf32>, %[[I:.*]]: index) +// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32> +// CHECK: acc.kernels dataOperands(%[[CREATE]] : memref<10xf32>) { +// DEVICE: %{{.*}} = memref.load %[[CREATE]][%[[I]]] : memref<10xf32> +// HOST: %{{.*}} = memref.load %[[A]][%[[I]]] : memref<10xf32> +// CHECK: acc.terminator +// CHECK: } + +// ----- + +func.func @test(%a: memref<10xf32>) { + %lb = arith.constant 0 : index + %st = arith.constant 1 : index + %c10 = arith.constant 10 : index + %create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32> + acc.parallel dataOperands(%create : memref<10xf32>) { + acc.loop (%i : index) = (%lb : index) to (%c10 : index) step (%st : index) { + %ci = memref.load %a[%i] : memref<10xf32> + acc.yield + } + acc.yield + } + return +} + +// CHECK: func.func @test +// CHECK-SAME: (%[[A:.*]]: memref<10xf32>) +// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32> +// CHECK: acc.parallel dataOperands(%[[CREATE]] : memref<10xf32>) { +// CHECK: acc.loop (%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) { +// DEVICE: %{{.*}} = memref.load %[[CREATE:.*]][%[[I]]] : memref<10xf32> +// CHECK: acc.yield +// CHECK: } +// CHECK: acc.yield +// CHECK: } |