diff options
author | Marius Brehler <marius.brehler@iml.fraunhofer.de> | 2024-03-07 11:34:11 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-07 11:34:11 +0100 |
commit | c40146c214a705a232848144d9412c8a7c73f0fe (patch) | |
tree | 50ea45a15488060aed47df9c7af9e68c3a0ec729 | |
parent | 6f54a54c6f5f644b4f4c79882154fd9737568c8e (diff) | |
download | llvm-c40146c214a705a232848144d9412c8a7c73f0fe.zip llvm-c40146c214a705a232848144d9412c8a7c73f0fe.tar.gz llvm-c40146c214a705a232848144d9412c8a7c73f0fe.tar.bz2 |
[mlir][EmitC] Add Arith to EmitC conversions (#84151)
This adds patterns and a pass to convert the Arith dialect to EmitC. For
now, this covers arithemtic binary ops operating on floating point
types.
It is not checked within the patterns whether the types, such as the
Tensor type, are supported in the respective EmitC operations. If
unsupported types should be converted, the conversion will fail anyway
because no legal EmitC operation can be created. This can clearly be
improved in a follow up, also resulting in better error messages.
Functions for such checks should not solely be used in the conversions
and should also be (re)used in the verifier.
-rw-r--r-- | mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h | 20 | ||||
-rw-r--r-- | mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h | 21 | ||||
-rw-r--r-- | mlir/include/mlir/Conversion/Passes.h | 1 | ||||
-rw-r--r-- | mlir/include/mlir/Conversion/Passes.td | 9 | ||||
-rw-r--r-- | mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp | 60 | ||||
-rw-r--r-- | mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp | 53 | ||||
-rw-r--r-- | mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt | 16 | ||||
-rw-r--r-- | mlir/lib/Conversion/CMakeLists.txt | 1 | ||||
-rw-r--r-- | mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir | 14 | ||||
-rw-r--r-- | utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 27 |
10 files changed, 222 insertions, 0 deletions
diff --git a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h new file mode 100644 index 0000000..9cb4368 --- /dev/null +++ b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h @@ -0,0 +1,20 @@ +//===- ArithToEmitC.h - Arith to EmitC Patterns -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H +#define MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H + +namespace mlir { +class RewritePatternSet; +class TypeConverter; + +void populateArithToEmitCPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns); +} // namespace mlir + +#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H diff --git a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h new file mode 100644 index 0000000..6b98fed --- /dev/null +++ b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h @@ -0,0 +1,21 @@ +//===- ArithToEmitCPass.h - Arith to EmitC Pass -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H +#define MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H + +#include <memory> + +namespace mlir { +class Pass; + +#define GEN_PASS_DECL_CONVERTARITHTOEMITC +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index 81f69210..f2aa4fb 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -13,6 +13,7 @@ #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h" #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h" +#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h" #include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 94fc7a7..bd81cc6 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -134,6 +134,15 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> { } //===----------------------------------------------------------------------===// +// ArithToEmitC +//===----------------------------------------------------------------------===// + +def ConvertArithToEmitC : Pass<"convert-arith-to-emitc"> { + let summary = "Convert Arith dialect to EmitC dialect"; + let dependentDialects = ["emitc::EmitCDialect"]; +} + +//===----------------------------------------------------------------------===// // ArithToLLVM //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp new file mode 100644 index 0000000..6909534 --- /dev/null +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -0,0 +1,60 @@ +//===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements patterns to convert the Arith dialect to the EmitC +// dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Conversion Patterns +//===----------------------------------------------------------------------===// + +namespace { +template <typename ArithOp, typename EmitCOp> +class ArithOpConversion final : public OpConversionPattern<ArithOp> { +public: + using OpConversionPattern<ArithOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, arithOp.getType(), + adaptor.getOperands()); + + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Pattern population +//===----------------------------------------------------------------------===// + +void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns) { + MLIRContext *ctx = patterns.getContext(); + + // clang-format off + patterns.add< + ArithOpConversion<arith::AddFOp, emitc::AddOp>, + ArithOpConversion<arith::DivFOp, emitc::DivOp>, + ArithOpConversion<arith::MulFOp, emitc::MulOp>, + ArithOpConversion<arith::SubFOp, emitc::SubOp> + >(typeConverter, ctx); + // clang-format on +} diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp new file mode 100644 index 0000000..b377c06 --- /dev/null +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp @@ -0,0 +1,53 @@ +//===- ArithToEmitCPass.cpp - Arith to EmitC Pass ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to convert the Arith dialect to the EmitC +// dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h" + +#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTARITHTOEMITC +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +namespace { +struct ConvertArithToEmitC + : public impl::ConvertArithToEmitCBase<ConvertArithToEmitC> { + void runOnOperation() override; +}; +} // namespace + +void ConvertArithToEmitC::runOnOperation() { + ConversionTarget target(getContext()); + + target.addLegalDialect<emitc::EmitCDialect>(); + target.addIllegalDialect<arith::ArithDialect>(); + target.addLegalOp<arith::ConstantOp>(); + + RewritePatternSet patterns(&getContext()); + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + + populateArithToEmitCPatterns(typeConverter, patterns); + + if (failed( + applyPartialConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); +} diff --git a/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt new file mode 100644 index 0000000..a3784f4 --- /dev/null +++ b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_conversion_library(MLIRArithToEmitC + ArithToEmitC.cpp + ArithToEmitCPass.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToEmitC + + DEPENDS + MLIRConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIREmitCDialect + MLIRPass + MLIRTransformUtils + ) diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 9e421f7..8219cf9 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -3,6 +3,7 @@ add_subdirectory(AMDGPUToROCDL) add_subdirectory(ArithCommon) add_subdirectory(ArithToAMDGPU) add_subdirectory(ArithToArmSME) +add_subdirectory(ArithToEmitC) add_subdirectory(ArithToLLVM) add_subdirectory(ArithToSPIRV) add_subdirectory(ArmNeon2dToIntr) diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir new file mode 100644 index 0000000..6a56474 --- /dev/null +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt -convert-arith-to-emitc %s | FileCheck %s + +func.func @arith_ops(%arg0: f32, %arg1: f32) { + // CHECK: [[V0:[^ ]*]] = emitc.add %arg0, %arg1 : (f32, f32) -> f32 + %0 = arith.addf %arg0, %arg1 : f32 + // CHECK: [[V1:[^ ]*]] = emitc.div %arg0, %arg1 : (f32, f32) -> f32 + %1 = arith.divf %arg0, %arg1 : f32 + // CHECK: [[V2:[^ ]*]] = emitc.mul %arg0, %arg1 : (f32, f32) -> f32 + %2 = arith.mulf %arg0, %arg1 : f32 + // CHECK: [[V3:[^ ]*]] = emitc.sub %arg0, %arg1 : (f32, f32) -> f32 + %3 = arith.subf %arg0, %arg1 : f32 + + return +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 86b38eb..9d6ca4e 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -4014,6 +4014,7 @@ cc_library( ":AffineToStandard", ":ArithToAMDGPU", ":ArithToArmSME", + ":ArithToEmitC", ":ArithToLLVM", ":ArithToSPIRV", ":ArmNeon2dToIntr", @@ -8163,6 +8164,32 @@ cc_library( ) cc_library( + name = "ArithToEmitC", + srcs = glob([ + "lib/Conversion/ArithToEmitC/*.cpp", + "lib/Conversion/ArithToEmitC/*.h", + ]), + hdrs = glob([ + "include/mlir/Conversion/ArithToEmitC/*.h", + ]), + includes = [ + "include", + "lib/Conversion/ArithToEmitC", + ], + deps = [ + ":ArithDialect", + ":ConversionPassIncGen", + ":EmitCDialect", + ":IR", + ":Pass", + ":Support", + ":TransformUtils", + ":Transforms", + "//llvm:Support", + ], +) + +cc_library( name = "ArithToLLVM", srcs = glob(["lib/Conversion/ArithToLLVM/*.cpp"]), hdrs = glob(["include/mlir/Conversion/ArithToLLVM/*.h"]), |