diff options
author | Tomer Solomon <tomsol2009@gmail.com> | 2025-01-20 10:26:41 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-01-20 09:26:41 +0100 |
commit | c0055ec434cbb132d7776f8b4c39e99b69fa97ea (patch) | |
tree | 6846b0a651ac10a9b28d1c77e552781cb186e342 | |
parent | 7a77f14c0abfbecbfb800ea8d974e66d81ee516a (diff) | |
download | llvm-c0055ec434cbb132d7776f8b4c39e99b69fa97ea.zip llvm-c0055ec434cbb132d7776f8b4c39e99b69fa97ea.tar.gz llvm-c0055ec434cbb132d7776f8b4c39e99b69fa97ea.tar.bz2 |
[mlir][EmitC] Add MathToEmitC pass for math function lowering to EmitC (#113799)
This commit introduces a new MathToEmitC conversion pass that lowers
selected math operations from the Math dialect to the emitc.call_opaque
operation in the EmitC dialect.
**Supported Math Operations:**
The following operations are converted:
- math.floor -> emitc.call_opaque<"floor">
- math.round -> emitc.call_opaque<"round">
- math.exp -> emitc.call_opaque<"exp">
- math.cos -> emitc.call_opaque<"cos">
- math.sin -> emitc.call_opaque<"sin">
- math.acos -> emitc.call_opaque<"acos">
- math.asin -> emitc.call_opaque<"asin">
- math.atan2 -> emitc.call_opaque<"atan2">
- math.ceil -> emitc.call_opaque<"ceil">
- math.absf -> emitc.call_opaque<"fabs">
- math.powf -> emitc.call_opaque<"pow">
**Target Language Standards:**
The pass supports targeting different language standards:
- C99: Generates calls with suffixes (e.g., floorf, fabsf) for
single-precision floats.
- CPP11: Prepends std:: to functions (e.g., std::floor, std::fabs).
**Design Decisions:**
The pass uses emitc.call_opaque instead of emitc.call to better emulate
C-style function overloading.
emitc.call_opaque does not require a unique type signature, making it
more suitable for operations like <math.h> functions that may be
overloaded for different types.
This design choice ensures compatibility with C/C++ conventions.
-rw-r--r-- | mlir/include/mlir/Conversion/MathToEmitC/MathToEmitC.h | 25 | ||||
-rw-r--r-- | mlir/include/mlir/Conversion/MathToEmitC/MathToEmitCPass.h | 21 | ||||
-rw-r--r-- | mlir/include/mlir/Conversion/Passes.h | 1 | ||||
-rw-r--r-- | mlir/include/mlir/Conversion/Passes.td | 22 | ||||
-rw-r--r-- | mlir/lib/Conversion/CMakeLists.txt | 1 | ||||
-rw-r--r-- | mlir/lib/Conversion/MathToEmitC/CMakeLists.txt | 19 | ||||
-rw-r--r-- | mlir/lib/Conversion/MathToEmitC/MathToEmitC.cpp | 85 | ||||
-rw-r--r-- | mlir/lib/Conversion/MathToEmitC/MathToEmitCPass.cpp | 53 | ||||
-rw-r--r-- | mlir/test/Conversion/MathToEmitC/math-to-emitc-failed.mlir | 23 | ||||
-rw-r--r-- | mlir/test/Conversion/MathToEmitC/math-to-emitc.mlir | 112 | ||||
-rw-r--r-- | utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 22 |
11 files changed, 384 insertions, 0 deletions
diff --git a/mlir/include/mlir/Conversion/MathToEmitC/MathToEmitC.h b/mlir/include/mlir/Conversion/MathToEmitC/MathToEmitC.h new file mode 100644 index 0000000..0fc33bf --- /dev/null +++ b/mlir/include/mlir/Conversion/MathToEmitC/MathToEmitC.h @@ -0,0 +1,25 @@ +//===- MathToEmitC.h - Math to EmitC Patterns -------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H +#define MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H +#include "mlir/Dialect/EmitC/IR/EmitC.h" +namespace mlir { +class RewritePatternSet; +namespace emitc { + +/// Enum to specify the language target for EmitC code generation. +enum class LanguageTarget { c99, cpp11 }; + +} // namespace emitc + +void populateConvertMathToEmitCPatterns(RewritePatternSet &patterns, + emitc::LanguageTarget languageTarget); +} // namespace mlir + +#endif // MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H diff --git a/mlir/include/mlir/Conversion/MathToEmitC/MathToEmitCPass.h b/mlir/include/mlir/Conversion/MathToEmitC/MathToEmitCPass.h new file mode 100644 index 0000000..c3861db --- /dev/null +++ b/mlir/include/mlir/Conversion/MathToEmitC/MathToEmitCPass.h @@ -0,0 +1,21 @@ +//===- MathToEmitCPass.h - Math to EmitC Pass -------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H +#define MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H + +#include "mlir/Conversion/MathToEmitC/MathToEmitC.h" +#include <memory> +namespace mlir { +class Pass; + +#define GEN_PASS_DECL_CONVERTMATHTOEMITC +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +#endif // MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index b577aa8..6a564e9 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -43,6 +43,7 @@ #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" #include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h" #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h" +#include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h" #include "mlir/Conversion/MathToFuncs/MathToFuncs.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/MathToLibm/MathToLibm.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 0f42ffb..3aea30b 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -791,6 +791,28 @@ def ConvertMathToSPIRV : Pass<"convert-math-to-spirv"> { } //===----------------------------------------------------------------------===// +// MathToEmitC +//===----------------------------------------------------------------------===// + +def ConvertMathToEmitC : Pass<"convert-math-to-emitc"> { + let summary = "Convert some Math operations to EmitC call_opaque operations"; + let description = [{ + This pass converts supported Math ops to `call_opaque` ops targeting libc/libm + functions. Unlike convert-math-to-funcs pass, converting to `call_opaque` ops + allows to overload the same function with different argument types. + }]; + let dependentDialects = ["emitc::EmitCDialect"]; + let options = [ + Option<"languageTarget", "language-target", "::mlir::emitc::LanguageTarget", + /*default=*/"::mlir::emitc::LanguageTarget::c99", "Select the language standard target for callees (c99 or cpp11).", + [{::llvm::cl::values( + clEnumValN(::mlir::emitc::LanguageTarget::c99, "c99", "c99"), + clEnumValN(::mlir::emitc::LanguageTarget::cpp11, "cpp11", "cpp11") + )}]> + ]; +} + +//===----------------------------------------------------------------------===// // MathToFuncs //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 62461c0..791e94e 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -33,6 +33,7 @@ add_subdirectory(IndexToLLVM) add_subdirectory(IndexToSPIRV) add_subdirectory(LinalgToStandard) add_subdirectory(LLVMCommon) +add_subdirectory(MathToEmitC) add_subdirectory(MathToFuncs) add_subdirectory(MathToLibm) add_subdirectory(MathToLLVM) diff --git a/mlir/lib/Conversion/MathToEmitC/CMakeLists.txt b/mlir/lib/Conversion/MathToEmitC/CMakeLists.txt new file mode 100644 index 0000000..12a99c31 --- /dev/null +++ b/mlir/lib/Conversion/MathToEmitC/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_conversion_library(MLIRMathToEmitC + MathToEmitC.cpp + MathToEmitCPass.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToEmitC + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIREmitCDialect + MLIRMathDialect + MLIRPass + MLIRTransformUtils +) diff --git a/mlir/lib/Conversion/MathToEmitC/MathToEmitC.cpp b/mlir/lib/Conversion/MathToEmitC/MathToEmitC.cpp new file mode 100644 index 0000000..9a0651a --- /dev/null +++ b/mlir/lib/Conversion/MathToEmitC/MathToEmitC.cpp @@ -0,0 +1,85 @@ +//===- MathToEmitC.cpp - Math to EmitC Patterns -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/MathToEmitC/MathToEmitC.h" + +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; + +namespace { +template <typename OpType> +class LowerToEmitCCallOpaque : public OpRewritePattern<OpType> { + std::string calleeStr; + emitc::LanguageTarget languageTarget; + +public: + LowerToEmitCCallOpaque(MLIRContext *context, std::string calleeStr, + emitc::LanguageTarget languageTarget) + : OpRewritePattern<OpType>(context), calleeStr(std::move(calleeStr)), + languageTarget(languageTarget) {} + + LogicalResult matchAndRewrite(OpType op, + PatternRewriter &rewriter) const override; +}; + +template <typename OpType> +LogicalResult LowerToEmitCCallOpaque<OpType>::matchAndRewrite( + OpType op, PatternRewriter &rewriter) const { + if (!llvm::all_of(op->getOperandTypes(), + llvm::IsaPred<Float32Type, Float64Type>) || + !llvm::all_of(op->getResultTypes(), + llvm::IsaPred<Float32Type, Float64Type>)) + return rewriter.notifyMatchFailure( + op.getLoc(), + "expected all operands and results to be of type f32 or f64"); + std::string modifiedCalleeStr = calleeStr; + if (languageTarget == emitc::LanguageTarget::cpp11) { + modifiedCalleeStr = "std::" + calleeStr; + } else if (languageTarget == emitc::LanguageTarget::c99) { + auto operandType = op->getOperandTypes()[0]; + if (operandType.isF32()) + modifiedCalleeStr = calleeStr + "f"; + } + rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>( + op, op.getType(), modifiedCalleeStr, op->getOperands()); + return success(); +} + +} // namespace + +// Populates patterns to replace `math` operations with `emitc.call_opaque`, +// using function names consistent with those in <math.h>. +void mlir::populateConvertMathToEmitCPatterns( + RewritePatternSet &patterns, emitc::LanguageTarget languageTarget) { + auto *context = patterns.getContext(); + patterns.insert<LowerToEmitCCallOpaque<math::FloorOp>>(context, "floor", + languageTarget); + patterns.insert<LowerToEmitCCallOpaque<math::RoundOp>>(context, "round", + languageTarget); + patterns.insert<LowerToEmitCCallOpaque<math::ExpOp>>(context, "exp", + languageTarget); + patterns.insert<LowerToEmitCCallOpaque<math::CosOp>>(context, "cos", + languageTarget); + patterns.insert<LowerToEmitCCallOpaque<math::SinOp>>(context, "sin", + languageTarget); + patterns.insert<LowerToEmitCCallOpaque<math::AcosOp>>(context, "acos", + languageTarget); + patterns.insert<LowerToEmitCCallOpaque<math::AsinOp>>(context, "asin", + languageTarget); + patterns.insert<LowerToEmitCCallOpaque<math::Atan2Op>>(context, "atan2", + languageTarget); + patterns.insert<LowerToEmitCCallOpaque<math::CeilOp>>(context, "ceil", + languageTarget); + patterns.insert<LowerToEmitCCallOpaque<math::AbsFOp>>(context, "fabs", + languageTarget); + patterns.insert<LowerToEmitCCallOpaque<math::PowFOp>>(context, "pow", + languageTarget); +} diff --git a/mlir/lib/Conversion/MathToEmitC/MathToEmitCPass.cpp b/mlir/lib/Conversion/MathToEmitC/MathToEmitCPass.cpp new file mode 100644 index 0000000..87a2764 --- /dev/null +++ b/mlir/lib/Conversion/MathToEmitC/MathToEmitCPass.cpp @@ -0,0 +1,53 @@ +//===- MathToEmitCPass.cpp - Math to EmitC Pass -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to convert the Math dialect to the EmitC dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h" +#include "mlir/Conversion/MathToEmitC/MathToEmitC.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTMATHTOEMITC +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; +namespace { + +// Replaces Math operations with `emitc.call_opaque` operations. +struct ConvertMathToEmitC + : public impl::ConvertMathToEmitCBase<ConvertMathToEmitC> { + using ConvertMathToEmitCBase::ConvertMathToEmitCBase; + +public: + void runOnOperation() final; +}; + +} // namespace + +void ConvertMathToEmitC::runOnOperation() { + ConversionTarget target(getContext()); + target.addLegalOp<emitc::CallOpaqueOp>(); + + target.addIllegalOp<math::FloorOp, math::ExpOp, math::RoundOp, math::CosOp, + math::SinOp, math::Atan2Op, math::CeilOp, math::AcosOp, + math::AsinOp, math::AbsFOp, math::PowFOp>(); + + RewritePatternSet patterns(&getContext()); + populateConvertMathToEmitCPatterns(patterns, languageTarget); + + if (failed( + applyPartialConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); +} diff --git a/mlir/test/Conversion/MathToEmitC/math-to-emitc-failed.mlir b/mlir/test/Conversion/MathToEmitC/math-to-emitc-failed.mlir new file mode 100644 index 0000000..f1de97c --- /dev/null +++ b/mlir/test/Conversion/MathToEmitC/math-to-emitc-failed.mlir @@ -0,0 +1,23 @@ +// RUN: mlir-opt -split-input-file -convert-math-to-emitc -verify-diagnostics %s + +func.func @unsupported_tensor_type(%arg0 : tensor<4xf32>) -> tensor<4xf32> { +// expected-error @+1 {{failed to legalize operation 'math.absf' that was explicitly marked illegal}} + %0 = math.absf %arg0 : tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- + +func.func @unsupported_f16_type(%arg0 : f16) -> f16 { +// expected-error @+1 {{failed to legalize operation 'math.absf' that was explicitly marked illegal}} + %0 = math.absf %arg0 : f16 + return %0 : f16 +} + +// ----- + +func.func @unsupported_f128_type(%arg0 : f128) -> f128 { +// expected-error @+1 {{failed to legalize operation 'math.absf' that was explicitly marked illegal}} + %0 = math.absf %arg0 : f128 + return %0 : f128 +} diff --git a/mlir/test/Conversion/MathToEmitC/math-to-emitc.mlir b/mlir/test/Conversion/MathToEmitC/math-to-emitc.mlir new file mode 100644 index 0000000..111d93d --- /dev/null +++ b/mlir/test/Conversion/MathToEmitC/math-to-emitc.mlir @@ -0,0 +1,112 @@ +// RUN: mlir-opt -convert-math-to-emitc=language-target=c99 %s | FileCheck %s --check-prefix=c99 +// RUN: mlir-opt -convert-math-to-emitc=language-target=cpp11 %s | FileCheck %s --check-prefix=cpp11 + +func.func @absf(%arg0: f32, %arg1: f64) { + // c99: emitc.call_opaque "fabsf" + // c99-NEXT: emitc.call_opaque "fabs" + // cpp11: emitc.call_opaque "std::fabs" + // cpp11-NEXT: emitc.call_opaque "std::fabs" + %0 = math.absf %arg0 : f32 + %1 = math.absf %arg1 : f64 + return +} + +func.func @floor(%arg0: f32, %arg1: f64) { + // c99: emitc.call_opaque "floorf" + // c99-NEXT: emitc.call_opaque "floor" + // cpp11: emitc.call_opaque "std::floor" + // cpp11-NEXT: emitc.call_opaque "std::floor" + %0 = math.floor %arg0 : f32 + %1 = math.floor %arg1 : f64 + return +} + +func.func @sin(%arg0: f32, %arg1: f64) { + // c99: emitc.call_opaque "sinf" + // c99-NEXT: emitc.call_opaque "sin" + // cpp11: emitc.call_opaque "std::sin" + // cpp11-NEXT: emitc.call_opaque "std::sin" + %0 = math.sin %arg0 : f32 + %1 = math.sin %arg1 : f64 + return +} + +func.func @cos(%arg0: f32, %arg1: f64) { + // c99: emitc.call_opaque "cosf" + // c99-NEXT: emitc.call_opaque "cos" + // cpp11: emitc.call_opaque "std::cos" + // cpp11-NEXT: emitc.call_opaque "std::cos" + %0 = math.cos %arg0 : f32 + %1 = math.cos %arg1 : f64 + return +} + +func.func @asin(%arg0: f32, %arg1: f64) { + // c99: emitc.call_opaque "asinf" + // c99-NEXT: emitc.call_opaque "asin" + // cpp11: emitc.call_opaque "std::asin" + // cpp11-NEXT: emitc.call_opaque "std::asin" + %0 = math.asin %arg0 : f32 + %1 = math.asin %arg1 : f64 + return +} + +func.func @acos(%arg0: f32, %arg1: f64) { + // c99: emitc.call_opaque "acosf" + // c99-NEXT: emitc.call_opaque "acos" + // cpp11: emitc.call_opaque "std::acos" + // cpp11-NEXT: emitc.call_opaque "std::acos" + %0 = math.acos %arg0 : f32 + %1 = math.acos %arg1 : f64 + return +} + +func.func @atan2(%arg0: f32, %arg1: f32, %arg2: f64, %arg3: f64) { + // c99: emitc.call_opaque "atan2f" + // c99-NEXT: emitc.call_opaque "atan2" + // cpp11: emitc.call_opaque "std::atan2" + // cpp11-NEXT: emitc.call_opaque "std::atan2" + %0 = math.atan2 %arg0, %arg1 : f32 + %1 = math.atan2 %arg2, %arg3 : f64 + return +} + +func.func @ceil(%arg0: f32, %arg1: f64) { + // c99: emitc.call_opaque "ceilf" + // c99-NEXT: emitc.call_opaque "ceil" + // cpp11: emitc.call_opaque "std::ceil" + // cpp11-NEXT: emitc.call_opaque "std::ceil" + %0 = math.ceil %arg0 : f32 + %1 = math.ceil %arg1 : f64 + return +} + +func.func @exp(%arg0: f32, %arg1: f64) { + // c99: emitc.call_opaque "expf" + // c99-NEXT: emitc.call_opaque "exp" + // cpp11: emitc.call_opaque "std::exp" + // cpp11-NEXT: emitc.call_opaque "std::exp" + %0 = math.exp %arg0 : f32 + %1 = math.exp %arg1 : f64 + return +} + +func.func @powf(%arg0: f32, %arg1: f32, %arg2: f64, %arg3: f64) { + // c99: emitc.call_opaque "powf" + // c99-NEXT: emitc.call_opaque "pow" + // cpp11: emitc.call_opaque "std::pow" + // cpp11-NEXT: emitc.call_opaque "std::pow" + %0 = math.powf %arg0, %arg1 : f32 + %1 = math.powf %arg2, %arg3 : f64 + return +} + +func.func @round(%arg0: f32, %arg1: f64) { + // c99: emitc.call_opaque "roundf" + // c99-NEXT: emitc.call_opaque "round" + // cpp11: emitc.call_opaque "std::round" + // cpp11-NEXT: emitc.call_opaque "std::round" + %0 = math.round %arg0 : f32 + %1 = math.round %arg1 : f64 + return +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index d3f3697..28127fb 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -4313,6 +4313,7 @@ cc_library( ":IndexToLLVM", ":IndexToSPIRV", ":LinalgToStandard", + ":MathToEmitC", ":MathToFuncs", ":MathToLLVM", ":MathToLibm", @@ -8902,6 +8903,27 @@ cc_library( ) cc_library( + name = "MathToEmitC", + srcs = glob([ + "lib/Conversion/MathToEmitC/*.cpp", + ]), + hdrs = glob([ + "include/mlir/Conversion/MathToEmitC/*.h", + ]), + includes = [ + "include", + "lib/Conversion/MathToEmitC", + ], + deps = [ + ":ConversionPassIncGen", + ":EmitCDialect", + ":MathDialect", + ":Pass", + ":TransformUtils", + ], +) + +cc_library( name = "MathToFuncs", srcs = glob(["lib/Conversion/MathToFuncs/*.cpp"]), hdrs = glob(["include/mlir/Conversion/MathToFuncs/*.h"]), |