aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp')
-rw-r--r--mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp36
1 files changed, 35 insertions, 1 deletions
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
index cf25c09..e78dd76 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -15,6 +15,7 @@
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -28,9 +29,11 @@ using namespace mlir;
namespace {
struct ConvertMemRefToEmitCPass
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
+ using Base::Base;
void runOnOperation() override {
TypeConverter converter;
-
+ ConvertMemRefToEmitCOptions options;
+ options.lowerToCpp = this->lowerToCpp;
// Fallback for other types.
converter.addConversion([](Type type) -> std::optional<Type> {
if (!emitc::isSupportedEmitCType(type))
@@ -50,6 +53,37 @@ struct ConvertMemRefToEmitCPass
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
+
+ mlir::ModuleOp module = getOperation();
+ module.walk([&](mlir::emitc::CallOpaqueOp callOp) {
+ if (callOp.getCallee() != alignedAllocFunctionName &&
+ callOp.getCallee() != mallocFunctionName) {
+ return mlir::WalkResult::advance();
+ }
+
+ for (auto &op : *module.getBody()) {
+ emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op);
+ if (!includeOp) {
+ continue;
+ }
+ if (includeOp.getIsStandardInclude() &&
+ ((options.lowerToCpp &&
+ includeOp.getInclude() == cppStandardLibraryHeader) ||
+ (!options.lowerToCpp &&
+ includeOp.getInclude() == cStandardLibraryHeader))) {
+ return mlir::WalkResult::interrupt();
+ }
+ }
+
+ mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
+ StringAttr includeAttr =
+ builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader
+ : cStandardLibraryHeader);
+ builder.create<mlir::emitc::IncludeOp>(
+ module.getLoc(), includeAttr,
+ /*is_standard_include=*/builder.getUnitAttr());
+ return mlir::WalkResult::interrupt();
+ });
}
};
} // namespace