aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h5
-rw-r--r--mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp44
-rw-r--r--mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp6
-rw-r--r--mlir/test/Conversion/SPIRVToLLVM/module-ops-to-llvm.mlir26
4 files changed, 81 insertions, 0 deletions
diff --git a/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
index 178d27d..1753723 100644
--- a/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
+++ b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
@@ -43,6 +43,11 @@ void populateSPIRVToLLVMFunctionConversionPatterns(
MLIRContext *context, LLVMTypeConverter &typeConverter,
OwningRewritePatternList &patterns);
+/// Populates the given patterns for module conversion from SPIR-V to LLVM.
+void populateSPIRVToLLVMModuleConversionPatterns(
+ MLIRContext *context, LLVMTypeConverter &typeConverter,
+ OwningRewritePatternList &patterns);
+
} // namespace mlir
#endif // MLIR_CONVERSION_SPIRVTOLLVM_CONVERTSPIRVTOLLVM_H
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
index e32fdc5..297e73a 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -278,6 +278,43 @@ public:
return success();
}
};
+
+//===----------------------------------------------------------------------===//
+// ModuleOp conversion
+//===----------------------------------------------------------------------===//
+
+class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
+public:
+ using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion;
+
+ LogicalResult
+ matchAndRewrite(spirv::ModuleOp spvModuleOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+
+ auto newModuleOp = rewriter.create<ModuleOp>(spvModuleOp.getLoc());
+ rewriter.inlineRegionBefore(spvModuleOp.body(), newModuleOp.getBody());
+
+ // Remove the terminator block that was automatically added by builder
+ rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
+ rewriter.eraseOp(spvModuleOp);
+ return success();
+ }
+};
+
+class ModuleEndConversionPattern
+ : public SPIRVToLLVMConversion<spirv::ModuleEndOp> {
+public:
+ using SPIRVToLLVMConversion<spirv::ModuleEndOp>::SPIRVToLLVMConversion;
+
+ LogicalResult
+ matchAndRewrite(spirv::ModuleEndOp moduleEndOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+
+ rewriter.replaceOpWithNewOp<ModuleTerminatorOp>(moduleEndOp);
+ return success();
+ }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -361,3 +398,10 @@ void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
OwningRewritePatternList &patterns) {
patterns.insert<FuncConversionPattern>(context, typeConverter);
}
+
+void mlir::populateSPIRVToLLVMModuleConversionPatterns(
+ MLIRContext *context, LLVMTypeConverter &typeConverter,
+ OwningRewritePatternList &patterns) {
+ patterns.insert<ModuleConversionPattern, ModuleEndConversionPattern>(
+ context, typeConverter);
+}
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
index 81a3a71..c512878 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
@@ -34,6 +34,7 @@ void ConvertSPIRVToLLVMPass::runOnOperation() {
LLVMTypeConverter converter(&getContext());
OwningRewritePatternList patterns;
+ populateSPIRVToLLVMModuleConversionPatterns(context, converter, patterns);
populateSPIRVToLLVMConversionPatterns(context, converter, patterns);
populateSPIRVToLLVMFunctionConversionPatterns(context, converter, patterns);
@@ -45,6 +46,11 @@ void ConvertSPIRVToLLVMPass::runOnOperation() {
ConversionTarget target(getContext());
target.addIllegalDialect<spirv::SPIRVDialect>();
target.addLegalDialect<LLVM::LLVMDialect>();
+
+ // set `ModuleOp` and `ModuleTerminatorOp` as legal for `spv.module`
+ // conversion.
+ target.addLegalOp<ModuleOp>();
+ target.addLegalOp<ModuleTerminatorOp>();
if (failed(applyPartialConversion(module, target, patterns)))
signalPassFailure();
}
diff --git a/mlir/test/Conversion/SPIRVToLLVM/module-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/module-ops-to-llvm.mlir
new file mode 100644
index 0000000..b8169a1
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/module-ops-to-llvm.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.module
+//===----------------------------------------------------------------------===//
+
+// CHECK: module
+spv.module Logical GLSL450 {}
+
+// CHECK: module
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], [SPV_KHR_16bit_storage]> {}
+
+// CHECK: module
+spv.module Logical GLSL450 {
+ // CHECK: }
+ spv._module_end
+}
+
+// CHECK: module
+spv.module Logical GLSL450 {
+ // CHECK-LABEL: llvm.func @empty()
+ spv.func @empty() -> () "None" {
+ // CHECK: llvm.return
+ spv.Return
+ }
+}