aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp12
-rw-r--r--mlir/test/mlir-cpu-runner/unranked_memref.mlir4
-rw-r--r--mlir/test/mlir-cpu-runner/utils.mlir4
3 files changed, 12 insertions, 8 deletions
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index febf4eb..cb66ae9 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -400,8 +400,13 @@ static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
// Insert before module terminator.
rewriter.setInsertionPoint(module.getBody(),
std::prev(module.getBody()->end()));
- rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType,
- ArrayRef<NamedAttribute>{});
+ FuncOp funcOp =
+ rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType,
+ ArrayRef<NamedAttribute>{});
+ // Insert a function attribute that will trigger the emission of the
+ // corresponding `_mlir_ciface_xxx` interface so that external libraries see
+ // a normalized ABI. This interface is added during std to llvm conversion.
+ funcOp.setAttr("llvm.emit_c_interface", UnitAttr::get(op->getContext()));
return fnNameAttr;
}
@@ -573,8 +578,7 @@ void ConvertLinalgToLLVMPass::runOnModule() {
LLVMTypeConverter converter(&getContext());
populateAffineToStdConversionPatterns(patterns, &getContext());
populateLoopToStdConversionPatterns(patterns, &getContext());
- populateStdToLLVMConversionPatterns(converter, patterns, /*useAlloca=*/false,
- /*emitCWrappers=*/true);
+ populateStdToLLVMConversionPatterns(converter, patterns);
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
populateVectorToLLVMConversionPatterns(converter, patterns);
populateLinalgToStandardConversionPatterns(patterns, &getContext());
diff --git a/mlir/test/mlir-cpu-runner/unranked_memref.mlir b/mlir/test/mlir-cpu-runner/unranked_memref.mlir
index 7447e9d..13d4db9 100644
--- a/mlir/test/mlir-cpu-runner/unranked_memref.mlir
+++ b/mlir/test/mlir-cpu-runner/unranked_memref.mlir
@@ -55,5 +55,5 @@ func @main() -> () {
return
}
-func @print_memref_i8(memref<*xi8>)
-func @print_memref_f32(memref<*xf32>)
+func @print_memref_i8(memref<*xi8>) attributes { llvm.emit_c_interface }
+func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
diff --git a/mlir/test/mlir-cpu-runner/utils.mlir b/mlir/test/mlir-cpu-runner/utils.mlir
index 2a56920..d3ab617 100644
--- a/mlir/test/mlir-cpu-runner/utils.mlir
+++ b/mlir/test/mlir-cpu-runner/utils.mlir
@@ -50,7 +50,7 @@ func @print_3d() {
// PRINT-3D-NEXT: 2, 2, 4, 2, 2
// PRINT-3D-NEXT: 2, 2, 2, 2, 2
-func @print_memref_f32(memref<*xf32>)
+func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
!vector_type_C = type vector<4x4xf32>
!matrix_type_CC = type memref<1x1x!vector_type_C>
@@ -71,4 +71,4 @@ func @vector_splat_2d() {
// PRINT-VECTOR-SPLAT-2D: Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [1, 1] strides = [1, 1] data =
// PRINT-VECTOR-SPLAT-2D-NEXT: [((10, 10, 10, 10), (10, 10, 10, 10), (10, 10, 10, 10), (10, 10, 10, 10))]
-func @print_memref_vector_4x4xf32(memref<?x?x!vector_type_C>)
+func @print_memref_vector_4x4xf32(memref<?x?x!vector_type_C>) attributes { llvm.emit_c_interface }