aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarius Brehler <marius.brehler@iml.fraunhofer.de>2024-03-07 11:34:11 +0100
committerGitHub <noreply@github.com>2024-03-07 11:34:11 +0100
commitc40146c214a705a232848144d9412c8a7c73f0fe (patch)
tree50ea45a15488060aed47df9c7af9e68c3a0ec729
parent6f54a54c6f5f644b4f4c79882154fd9737568c8e (diff)
downloadllvm-c40146c214a705a232848144d9412c8a7c73f0fe.zip
llvm-c40146c214a705a232848144d9412c8a7c73f0fe.tar.gz
llvm-c40146c214a705a232848144d9412c8a7c73f0fe.tar.bz2
[mlir][EmitC] Add Arith to EmitC conversions (#84151)
This adds patterns and a pass to convert the Arith dialect to EmitC. For now, this covers arithemtic binary ops operating on floating point types. It is not checked within the patterns whether the types, such as the Tensor type, are supported in the respective EmitC operations. If unsupported types should be converted, the conversion will fail anyway because no legal EmitC operation can be created. This can clearly be improved in a follow up, also resulting in better error messages. Functions for such checks should not solely be used in the conversions and should also be (re)used in the verifier.
-rw-r--r--mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h20
-rw-r--r--mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h21
-rw-r--r--mlir/include/mlir/Conversion/Passes.h1
-rw-r--r--mlir/include/mlir/Conversion/Passes.td9
-rw-r--r--mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp60
-rw-r--r--mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp53
-rw-r--r--mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt16
-rw-r--r--mlir/lib/Conversion/CMakeLists.txt1
-rw-r--r--mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir14
-rw-r--r--utils/bazel/llvm-project-overlay/mlir/BUILD.bazel27
10 files changed, 222 insertions, 0 deletions
diff --git a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
new file mode 100644
index 0000000..9cb4368
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
@@ -0,0 +1,20 @@
+//===- ArithToEmitC.h - Arith to EmitC Patterns -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
+#define MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
+
+namespace mlir {
+class RewritePatternSet;
+class TypeConverter;
+
+void populateArithToEmitCPatterns(TypeConverter &typeConverter,
+ RewritePatternSet &patterns);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
diff --git a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h
new file mode 100644
index 0000000..6b98fed
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h
@@ -0,0 +1,21 @@
+//===- ArithToEmitCPass.h - Arith to EmitC Pass -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H
+#define MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTARITHTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 81f69210..f2aa4fb 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -13,6 +13,7 @@
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 94fc7a7..bd81cc6 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -134,6 +134,15 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
}
//===----------------------------------------------------------------------===//
+// ArithToEmitC
+//===----------------------------------------------------------------------===//
+
+def ConvertArithToEmitC : Pass<"convert-arith-to-emitc"> {
+ let summary = "Convert Arith dialect to EmitC dialect";
+ let dependentDialects = ["emitc::EmitCDialect"];
+}
+
+//===----------------------------------------------------------------------===//
// ArithToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
new file mode 100644
index 0000000..6909534
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -0,0 +1,60 @@
+//===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns to convert the Arith dialect to the EmitC
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Conversion Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+template <typename ArithOp, typename EmitCOp>
+class ArithOpConversion final : public OpConversionPattern<ArithOp> {
+public:
+ using OpConversionPattern<ArithOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, arithOp.getType(),
+ adaptor.getOperands());
+
+ return success();
+ }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern population
+//===----------------------------------------------------------------------===//
+
+void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
+ RewritePatternSet &patterns) {
+ MLIRContext *ctx = patterns.getContext();
+
+ // clang-format off
+ patterns.add<
+ ArithOpConversion<arith::AddFOp, emitc::AddOp>,
+ ArithOpConversion<arith::DivFOp, emitc::DivOp>,
+ ArithOpConversion<arith::MulFOp, emitc::MulOp>,
+ ArithOpConversion<arith::SubFOp, emitc::SubOp>
+ >(typeConverter, ctx);
+ // clang-format on
+}
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
new file mode 100644
index 0000000..b377c06
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
@@ -0,0 +1,53 @@
+//===- ArithToEmitCPass.cpp - Arith to EmitC Pass ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert the Arith dialect to the EmitC
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
+
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTARITHTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+struct ConvertArithToEmitC
+ : public impl::ConvertArithToEmitCBase<ConvertArithToEmitC> {
+ void runOnOperation() override;
+};
+} // namespace
+
+void ConvertArithToEmitC::runOnOperation() {
+ ConversionTarget target(getContext());
+
+ target.addLegalDialect<emitc::EmitCDialect>();
+ target.addIllegalDialect<arith::ArithDialect>();
+ target.addLegalOp<arith::ConstantOp>();
+
+ RewritePatternSet patterns(&getContext());
+
+ TypeConverter typeConverter;
+ typeConverter.addConversion([](Type type) { return type; });
+
+ populateArithToEmitCPatterns(typeConverter, patterns);
+
+ if (failed(
+ applyPartialConversion(getOperation(), target, std::move(patterns))))
+ signalPassFailure();
+}
diff --git a/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
new file mode 100644
index 0000000..a3784f4
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_conversion_library(MLIRArithToEmitC
+ ArithToEmitC.cpp
+ ArithToEmitCPass.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToEmitC
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRArithDialect
+ MLIREmitCDialect
+ MLIRPass
+ MLIRTransformUtils
+ )
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 9e421f7..8219cf9 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -3,6 +3,7 @@ add_subdirectory(AMDGPUToROCDL)
add_subdirectory(ArithCommon)
add_subdirectory(ArithToAMDGPU)
add_subdirectory(ArithToArmSME)
+add_subdirectory(ArithToEmitC)
add_subdirectory(ArithToLLVM)
add_subdirectory(ArithToSPIRV)
add_subdirectory(ArmNeon2dToIntr)
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
new file mode 100644
index 0000000..6a56474
--- /dev/null
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt -convert-arith-to-emitc %s | FileCheck %s
+
+func.func @arith_ops(%arg0: f32, %arg1: f32) {
+ // CHECK: [[V0:[^ ]*]] = emitc.add %arg0, %arg1 : (f32, f32) -> f32
+ %0 = arith.addf %arg0, %arg1 : f32
+ // CHECK: [[V1:[^ ]*]] = emitc.div %arg0, %arg1 : (f32, f32) -> f32
+ %1 = arith.divf %arg0, %arg1 : f32
+ // CHECK: [[V2:[^ ]*]] = emitc.mul %arg0, %arg1 : (f32, f32) -> f32
+ %2 = arith.mulf %arg0, %arg1 : f32
+ // CHECK: [[V3:[^ ]*]] = emitc.sub %arg0, %arg1 : (f32, f32) -> f32
+ %3 = arith.subf %arg0, %arg1 : f32
+
+ return
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 86b38eb..9d6ca4e 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4014,6 +4014,7 @@ cc_library(
":AffineToStandard",
":ArithToAMDGPU",
":ArithToArmSME",
+ ":ArithToEmitC",
":ArithToLLVM",
":ArithToSPIRV",
":ArmNeon2dToIntr",
@@ -8163,6 +8164,32 @@ cc_library(
)
cc_library(
+ name = "ArithToEmitC",
+ srcs = glob([
+ "lib/Conversion/ArithToEmitC/*.cpp",
+ "lib/Conversion/ArithToEmitC/*.h",
+ ]),
+ hdrs = glob([
+ "include/mlir/Conversion/ArithToEmitC/*.h",
+ ]),
+ includes = [
+ "include",
+ "lib/Conversion/ArithToEmitC",
+ ],
+ deps = [
+ ":ArithDialect",
+ ":ConversionPassIncGen",
+ ":EmitCDialect",
+ ":IR",
+ ":Pass",
+ ":Support",
+ ":TransformUtils",
+ ":Transforms",
+ "//llvm:Support",
+ ],
+)
+
+cc_library(
name = "ArithToLLVM",
srcs = glob(["lib/Conversion/ArithToLLVM/*.cpp"]),
hdrs = glob(["include/mlir/Conversion/ArithToLLVM/*.h"]),