//===- MathToEmitC.cpp - Math to EmitC Patterns -----------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/MathToEmitC/MathToEmitC.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; namespace { template class LowerToEmitCCallOpaque : public OpRewritePattern { std::string calleeStr; emitc::LanguageTarget languageTarget; public: LowerToEmitCCallOpaque(MLIRContext *context, std::string calleeStr, emitc::LanguageTarget languageTarget) : OpRewritePattern(context), calleeStr(std::move(calleeStr)), languageTarget(languageTarget) {} LogicalResult matchAndRewrite(OpType op, PatternRewriter &rewriter) const override; }; template LogicalResult LowerToEmitCCallOpaque::matchAndRewrite( OpType op, PatternRewriter &rewriter) const { if (!llvm::all_of(op->getOperandTypes(), llvm::IsaPred) || !llvm::all_of(op->getResultTypes(), llvm::IsaPred)) return rewriter.notifyMatchFailure( op.getLoc(), "expected all operands and results to be of type f32 or f64"); std::string modifiedCalleeStr = calleeStr; if (languageTarget == emitc::LanguageTarget::cpp11) { modifiedCalleeStr = "std::" + calleeStr; } else if (languageTarget == emitc::LanguageTarget::c99) { auto operandType = op->getOperandTypes()[0]; if (operandType.isF32()) modifiedCalleeStr = calleeStr + "f"; } rewriter.replaceOpWithNewOp( op, op.getType(), modifiedCalleeStr, op->getOperands()); return success(); } } // namespace // Populates patterns to replace `math` operations with `emitc.call_opaque`, // using function names consistent with those in . 
void mlir::populateConvertMathToEmitCPatterns(
    RewritePatternSet &patterns, emitc::LanguageTarget languageTarget) {
  auto *context = patterns.getContext();

  // Each entry registers one LowerToEmitCCallOpaque pattern, pairing a math
  // op with the undecorated name of the function it is lowered to.
  // `languageTarget` is forwarded unchanged to every pattern.
  // NOTE(review): the op classes below were reconstructed from the callee
  // names — confirm against the Math dialect op definitions.
  patterns.insert<LowerToEmitCCallOpaque<math::FloorOp>>(context, "floor",
                                                         languageTarget);
  patterns.insert<LowerToEmitCCallOpaque<math::RoundOp>>(context, "round",
                                                         languageTarget);
  patterns.insert<LowerToEmitCCallOpaque<math::ExpOp>>(context, "exp",
                                                       languageTarget);
  patterns.insert<LowerToEmitCCallOpaque<math::CosOp>>(context, "cos",
                                                       languageTarget);
  patterns.insert<LowerToEmitCCallOpaque<math::SinOp>>(context, "sin",
                                                       languageTarget);
  patterns.insert<LowerToEmitCCallOpaque<math::AcosOp>>(context, "acos",
                                                        languageTarget);
  patterns.insert<LowerToEmitCCallOpaque<math::AsinOp>>(context, "asin",
                                                        languageTarget);
  patterns.insert<LowerToEmitCCallOpaque<math::Atan2Op>>(context, "atan2",
                                                         languageTarget);
  patterns.insert<LowerToEmitCCallOpaque<math::CeilOp>>(context, "ceil",
                                                        languageTarget);
  patterns.insert<LowerToEmitCCallOpaque<math::AbsFOp>>(context, "fabs",
                                                        languageTarget);
  patterns.insert<LowerToEmitCCallOpaque<math::PowFOp>>(context, "pow",
                                                        languageTarget);
}