diff options
91 files changed, 692 insertions, 214 deletions
diff --git a/flang/unittests/Lower/OpenMPLoweringTest.cpp b/flang/unittests/Lower/OpenMPLoweringTest.cpp index ad6fe73..4c23845 100644 --- a/flang/unittests/Lower/OpenMPLoweringTest.cpp +++ b/flang/unittests/Lower/OpenMPLoweringTest.cpp @@ -15,8 +15,7 @@ class OpenMPLoweringTest : public testing::Test { protected: void SetUp() override { - mlir::registerDialect<mlir::omp::OpenMPDialect>(); - mlir::registerAllDialects(&ctx); + ctx.getOrLoadDialect<mlir::omp::OpenMPDialect>(); mlirOpBuilder.reset(new mlir::OpBuilder(&ctx)); } diff --git a/mlir/examples/standalone/standalone-opt/standalone-opt.cpp b/mlir/examples/standalone/standalone-opt/standalone-opt.cpp index 5c99058..eb624b3 100644 --- a/mlir/examples/standalone/standalone-opt/standalone-opt.cpp +++ b/mlir/examples/standalone/standalone-opt/standalone-opt.cpp @@ -76,7 +76,7 @@ int main(int argc, char **argv) { if (showDialects) { mlir::MLIRContext context; llvm::outs() << "Registered Dialects:\n"; - for (mlir::Dialect *dialect : context.getRegisteredDialects()) { + for (mlir::Dialect *dialect : context.getLoadedDialects()) { llvm::outs() << dialect->getNamespace() << "\n"; } return 0; diff --git a/mlir/examples/toy/Ch2/toyc.cpp b/mlir/examples/toy/Ch2/toyc.cpp index d0880ce..99232d8 100644 --- a/mlir/examples/toy/Ch2/toyc.cpp +++ b/mlir/examples/toy/Ch2/toyc.cpp @@ -68,10 +68,9 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) { } int dumpMLIR() { - // Register our Dialect with MLIR. - mlir::registerDialect<mlir::toy::ToyDialect>(); - - mlir::MLIRContext context; + mlir::MLIRContext context(/*loadAllDialects=*/false); + // Load our Dialect in this MLIR Context. + context.getOrLoadDialect<mlir::toy::ToyDialect>(); // Handle '.toy' input to the compiler. if (inputType != InputType::MLIR && diff --git a/mlir/examples/toy/Ch3/toyc.cpp b/mlir/examples/toy/Ch3/toyc.cpp index f9d5631..d0430ce 100644 --- a/mlir/examples/toy/Ch3/toyc.cpp +++ b/mlir/examples/toy/Ch3/toyc.cpp @@ -102,10 +102,10 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context, } int dumpMLIR() { - // Register our Dialect with MLIR. - mlir::registerDialect<mlir::toy::ToyDialect>(); + mlir::MLIRContext context(/*loadAllDialects=*/false); + // Load our Dialect in this MLIR Context. + context.getOrLoadDialect<mlir::toy::ToyDialect>(); - mlir::MLIRContext context; mlir::OwningModuleRef module; llvm::SourceMgr sourceMgr; mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context); diff --git a/mlir/examples/toy/Ch4/toyc.cpp b/mlir/examples/toy/Ch4/toyc.cpp index e11f35c..9f95887 100644 --- a/mlir/examples/toy/Ch4/toyc.cpp +++ b/mlir/examples/toy/Ch4/toyc.cpp @@ -103,10 +103,10 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context, } int dumpMLIR() { - // Register our Dialect with MLIR. - mlir::registerDialect<mlir::toy::ToyDialect>(); + mlir::MLIRContext context(/*loadAllDialects=*/false); + // Load our Dialect in this MLIR Context. + context.getOrLoadDialect<mlir::toy::ToyDialect>(); - mlir::MLIRContext context; mlir::OwningModuleRef module; llvm::SourceMgr sourceMgr; mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context); diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp index 3097681..1077fc9 100644 --- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp @@ -256,6 +256,10 @@ struct TransposeOpLowering : public ConversionPattern { namespace { struct ToyToAffineLoweringPass : public PassWrapper<ToyToAffineLoweringPass, FunctionPass> { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<AffineDialect>(); + registry.insert<StandardOpsDialect>(); + } void runOnFunction() final; }; } // end anonymous namespace. diff --git a/mlir/examples/toy/Ch5/toyc.cpp b/mlir/examples/toy/Ch5/toyc.cpp index ed04969..16faac0 100644 --- a/mlir/examples/toy/Ch5/toyc.cpp +++ b/mlir/examples/toy/Ch5/toyc.cpp @@ -106,10 +106,10 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context, } int dumpMLIR() { - // Register our Dialect with MLIR. - mlir::registerDialect<mlir::toy::ToyDialect>(); + mlir::MLIRContext context(/*loadAllDialects=*/false); + // Load our Dialect in this MLIR Context. + context.getOrLoadDialect<mlir::toy::ToyDialect>(); - mlir::MLIRContext context; mlir::OwningModuleRef module; llvm::SourceMgr sourceMgr; mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context); diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp index cac3415..9ff9eb4 100644 --- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp @@ -255,6 +255,10 @@ struct TransposeOpLowering : public ConversionPattern { namespace { struct ToyToAffineLoweringPass : public PassWrapper<ToyToAffineLoweringPass, FunctionPass> { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<AffineDialect>(); + registry.insert<StandardOpsDialect>(); + } void runOnFunction() final; }; } // end anonymous namespace. diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp index 74b32dc..8020fb3 100644 --- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp @@ -159,6 +159,10 @@ private: namespace { struct ToyToLLVMLoweringPass : public PassWrapper<ToyToLLVMLoweringPass, OperationPass<ModuleOp>> { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<LLVM::LLVMDialect>(); + registry.insert<scf::SCFDialect>(); + } void runOnOperation() final; }; } // end anonymous namespace diff --git a/mlir/examples/toy/Ch6/toyc.cpp b/mlir/examples/toy/Ch6/toyc.cpp index bdcdf1a..9504a38 100644 --- a/mlir/examples/toy/Ch6/toyc.cpp +++ b/mlir/examples/toy/Ch6/toyc.cpp @@ -255,10 +255,10 @@ int main(int argc, char **argv) { // If we aren't dumping the AST, then we are compiling with/to MLIR. - // Register our Dialect with MLIR. - mlir::registerDialect<mlir::toy::ToyDialect>(); + mlir::MLIRContext context(/*loadAllDialects=*/false); + // Load our Dialect in this MLIR Context. + context.getOrLoadDialect<mlir::toy::ToyDialect>(); - mlir::MLIRContext context; mlir::OwningModuleRef module; if (int error = loadAndProcessMLIR(context, module)) return error; diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp index 3097681..1077fc9 100644 --- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp @@ -256,6 +256,10 @@ struct TransposeOpLowering : public ConversionPattern { namespace { struct ToyToAffineLoweringPass : public PassWrapper<ToyToAffineLoweringPass, FunctionPass> { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<AffineDialect>(); + registry.insert<StandardOpsDialect>(); + } void runOnFunction() final; }; } // end anonymous namespace. diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp index 74b32dc..8020fb3 100644 --- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -159,6 +159,10 @@ private: namespace { struct ToyToLLVMLoweringPass : public PassWrapper<ToyToLLVMLoweringPass, OperationPass<ModuleOp>> { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<LLVM::LLVMDialect>(); + registry.insert<scf::SCFDialect>(); + } void runOnOperation() final; }; } // end anonymous namespace diff --git a/mlir/examples/toy/Ch7/toyc.cpp b/mlir/examples/toy/Ch7/toyc.cpp index c1cc207..cb3b455 100644 --- a/mlir/examples/toy/Ch7/toyc.cpp +++ b/mlir/examples/toy/Ch7/toyc.cpp @@ -256,10 +256,10 @@ int main(int argc, char **argv) { // If we aren't dumping the AST, then we are compiling with/to MLIR. - // Register our Dialect with MLIR. - mlir::registerDialect<mlir::toy::ToyDialect>(); + mlir::MLIRContext context(/*loadAllDialects=*/false); + // Load our Dialect in this MLIR Context. + context.getOrLoadDialect<mlir::toy::ToyDialect>(); - mlir::MLIRContext context; mlir::OwningModuleRef module; if (int error = loadAndProcessMLIR(context, module)) return error; diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 6b5be2d..f9ec4d1 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -90,6 +90,12 @@ MlirContext mlirContextCreate(); /** Takes an MLIR context owned by the caller and destroys it. */ void mlirContextDestroy(MlirContext context); +/** Load all the globally registered dialects in the provided context. + * TODO: remove the concept of globally registered dialect by exposing the + * DialectRegistry. + */ +void mlirContextLoadAllDialects(MlirContext context); + /*============================================================================*/ /* Location API. */ /*============================================================================*/ diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 4d4fe06..0c40bb3 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -66,6 +66,11 @@ def ConvertAffineToStandard : Pass<"lower-affine"> { `affine.apply`. }]; let constructor = "mlir::createLowerAffinePass()"; + let dependentDialects = [ + "scf::SCFDialect", + "StandardOpsDialect", + "vector::VectorDialect" + ]; } //===----------------------------------------------------------------------===// @@ -76,6 +81,7 @@ def ConvertAVX512ToLLVM : Pass<"convert-avx512-to-llvm", "ModuleOp"> { let summary = "Convert the operations from the avx512 dialect into the LLVM " "dialect"; let constructor = "mlir::createConvertAVX512ToLLVMPass()"; + let dependentDialects = ["LLVM::LLVMDialect", "LLVM::LLVMAVX512Dialect"]; } //===----------------------------------------------------------------------===// @@ -98,6 +104,7 @@ def GpuToLLVMConversionPass : Pass<"gpu-to-llvm", "ModuleOp"> { def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> { let summary = "Generate NVVM operations for gpu operations"; let constructor = "mlir::createLowerGpuOpsToNVVMOpsPass()"; + let dependentDialects = ["NVVM::NVVMDialect"]; let options = [ Option<"indexBitwidth", "index-bitwidth", "unsigned", /*default=kDeriveIndexBitwidthFromDataLayout*/"0", @@ -112,6 +119,7 @@ def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> { def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> { let summary = "Generate ROCDL operations for gpu operations"; let constructor = "mlir::createLowerGpuOpsToROCDLOpsPass()"; + let dependentDialects = ["ROCDL::ROCDLDialect"]; let options = [ Option<"indexBitwidth", "index-bitwidth", "unsigned", /*default=kDeriveIndexBitwidthFromDataLayout*/"0", @@ -126,6 +134,7 @@ def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> { def ConvertGPUToSPIRV : Pass<"convert-gpu-to-spirv", "ModuleOp"> { let summary = "Convert GPU dialect to SPIR-V dialect"; let constructor = "mlir::createConvertGPUToSPIRVPass()"; + let dependentDialects = ["spirv::SPIRVDialect"]; } //===----------------------------------------------------------------------===// @@ -136,6 +145,7 @@ def ConvertGpuLaunchFuncToVulkanLaunchFunc : Pass<"convert-gpu-launch-to-vulkan-launch", "ModuleOp"> { let summary = "Convert gpu.launch_func to vulkanLaunch external call"; let constructor = "mlir::createConvertGpuLaunchFuncToVulkanLaunchFuncPass()"; + let dependentDialects = ["spirv::SPIRVDialect"]; } def ConvertVulkanLaunchFuncToVulkanCalls @@ -143,6 +153,7 @@ def ConvertVulkanLaunchFuncToVulkanCalls let summary = "Convert vulkanLaunch external call to Vulkan runtime external " "calls"; let constructor = "mlir::createConvertVulkanLaunchFuncToVulkanCallsPass()"; + let dependentDialects = ["LLVM::LLVMDialect"]; } //===----------------------------------------------------------------------===// @@ -153,6 +164,7 @@ def ConvertLinalgToLLVM : Pass<"convert-linalg-to-llvm", "ModuleOp"> { let summary = "Convert the operations from the linalg dialect into the LLVM " "dialect"; let constructor = "mlir::createConvertLinalgToLLVMPass()"; + let dependentDialects = ["scf::SCFDialect", "LLVM::LLVMDialect"]; } //===----------------------------------------------------------------------===// @@ -163,6 +175,7 @@ def ConvertLinalgToStandard : Pass<"convert-linalg-to-std", "ModuleOp"> { let summary = "Convert the operations from the linalg dialect into the " "Standard dialect"; let constructor = "mlir::createConvertLinalgToStandardPass()"; + let dependentDialects = ["StandardOpsDialect"]; } //===----------------------------------------------------------------------===// @@ -172,6 +185,7 @@ def ConvertLinalgToStandard : Pass<"convert-linalg-to-std", "ModuleOp"> { def ConvertLinalgToSPIRV : Pass<"convert-linalg-to-spirv", "ModuleOp"> { let summary = "Convert Linalg ops to SPIR-V ops"; let constructor = "mlir::createLinalgToSPIRVPass()"; + let dependentDialects = ["spirv::SPIRVDialect"]; } //===----------------------------------------------------------------------===// @@ -182,6 +196,7 @@ def SCFToStandard : Pass<"convert-scf-to-std"> { let summary = "Convert SCF dialect to Standard dialect, replacing structured" " control flow with a CFG"; let constructor = "mlir::createLowerToCFGPass()"; + let dependentDialects = ["StandardOpsDialect"]; } //===----------------------------------------------------------------------===// @@ -191,6 +206,7 @@ def SCFToStandard : Pass<"convert-scf-to-std"> { def ConvertAffineForToGPU : FunctionPass<"convert-affine-for-to-gpu"> { let summary = "Convert top-level AffineFor Ops to GPU kernels"; let constructor = "mlir::createAffineForToGPUPass()"; + let dependentDialects = ["gpu::GPUDialect"]; let options = [ Option<"numBlockDims", "gpu-block-dims", "unsigned", /*default=*/"1u", "Number of GPU block dimensions for mapping">, @@ -202,6 +218,7 @@ def ConvertAffineForToGPU : FunctionPass<"convert-affine-for-to-gpu"> { def ConvertParallelLoopToGpu : Pass<"convert-parallel-loops-to-gpu"> { let summary = "Convert mapped scf.parallel ops to gpu launch operations"; let constructor = "mlir::createParallelLoopToGpuPass()"; + let dependentDialects = ["AffineDialect", "gpu::GPUDialect"]; } //===----------------------------------------------------------------------===// @@ -212,6 +229,7 @@ def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> { let summary = "Convert operations from the shape dialect into the standard " "dialect"; let constructor = "mlir::createConvertShapeToStandardPass()"; + let dependentDialects = ["StandardOpsDialect"]; } //===----------------------------------------------------------------------===// @@ -221,6 +239,7 @@ def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> { def ConvertShapeToSCF : FunctionPass<"convert-shape-to-scf"> { let summary = "Convert operations from the shape dialect to the SCF dialect"; let constructor = "mlir::createConvertShapeToSCFPass()"; + let dependentDialects = ["scf::SCFDialect"]; } //===----------------------------------------------------------------------===// @@ -230,6 +249,7 @@ def ConvertShapeToSCF : FunctionPass<"convert-shape-to-scf"> { def ConvertSPIRVToLLVM : Pass<"convert-spirv-to-llvm", "ModuleOp"> { let summary = "Convert SPIR-V dialect to LLVM dialect"; let constructor = "mlir::createConvertSPIRVToLLVMPass()"; + let dependentDialects = ["LLVM::LLVMDialect"]; } //===----------------------------------------------------------------------===// @@ -264,6 +284,7 @@ def ConvertStandardToLLVM : Pass<"convert-std-to-llvm", "ModuleOp"> { LLVM IR types. }]; let constructor = "mlir::createLowerToLLVMPass()"; + let dependentDialects = ["LLVM::LLVMDialect"]; let options = [ Option<"useAlignedAlloc", "use-aligned-alloc", "bool", /*default=*/"false", "Use aligned_alloc in place of malloc for heap allocations">, @@ -287,11 +308,13 @@ def ConvertStandardToLLVM : Pass<"convert-std-to-llvm", "ModuleOp"> { def LegalizeStandardForSPIRV : Pass<"legalize-std-for-spirv"> { let summary = "Legalize standard ops for SPIR-V lowering"; let constructor = "mlir::createLegalizeStdOpsForSPIRVLoweringPass()"; + let dependentDialects = ["spirv::SPIRVDialect"]; } def ConvertStandardToSPIRV : Pass<"convert-std-to-spirv", "ModuleOp"> { let summary = "Convert Standard Ops to SPIR-V dialect"; let constructor = "mlir::createConvertStandardToSPIRVPass()"; + let dependentDialects = ["spirv::SPIRVDialect"]; } //===----------------------------------------------------------------------===// @@ -302,6 +325,7 @@ def ConvertVectorToSCF : FunctionPass<"convert-vector-to-scf"> { let summary = "Lower the operations from the vector dialect into the SCF " "dialect"; let constructor = "mlir::createConvertVectorToSCFPass()"; + let dependentDialects = ["AffineDialect", "scf::SCFDialect"]; let options = [ Option<"fullUnroll", "full-unroll", "bool", /*default=*/"false", "Perform full unrolling when converting vector transfers to SCF">, @@ -316,6 +340,7 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> { let summary = "Lower the operations from the vector dialect into the LLVM " "dialect"; let constructor = "mlir::createConvertVectorToLLVMPass()"; + let dependentDialects = ["LLVM::LLVMDialect"]; let options = [ Option<"reassociateFPReductions", "reassociate-fp-reductions", "bool", /*default=*/"false", @@ -331,6 +356,7 @@ def ConvertVectorToROCDL : Pass<"convert-vector-to-rocdl", "ModuleOp"> { let summary = "Lower the operations from the vector dialect into the ROCDL " "dialect"; let constructor = "mlir::createConvertVectorToROCDLPass()"; + let dependentDialects = ["ROCDL::ROCDLDialect"]; } #endif // MLIR_CONVERSION_PASSES diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td index 8106400..f43fabd 100644 --- a/mlir/include/mlir/Dialect/Affine/Passes.td +++ b/mlir/include/mlir/Dialect/Affine/Passes.td @@ -94,6 +94,7 @@ def AffineLoopUnrollAndJam : FunctionPass<"affine-loop-unroll-jam"> { def AffineVectorize : FunctionPass<"affine-super-vectorize"> { let summary = "Vectorize to a target independent n-D vector abstraction"; let constructor = "mlir::createSuperVectorizePass()"; + let dependentDialects = ["vector::VectorDialect"]; let options = [ ListOption<"vectorSizes", "virtual-vector-size", "int64_t", "Specify an n-D virtual vector size for vectorization", diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h index 04700f0..2f465f0 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -15,6 +15,7 @@ #define MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_ #include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Function.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td index d21f5bc..2617a2d 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -19,6 +19,11 @@ include "mlir/IR/OpBase.td" def LLVM_Dialect : Dialect { let name = "llvm"; let cppNamespace = "LLVM"; + + /// FIXME: at the moment this is a dependency of the translation to LLVM IR, + /// not really one of this dialect per-se. + let dependentDialects = [ "omp::OpenMPDialect" ]; + let hasRegionArgAttrVerify = 1; let extraClassDeclaration = [{ ~LLVMDialect(); diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h index 86d437c..9cc5314 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h @@ -14,6 +14,7 @@ #ifndef MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_ #define MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_ +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 5f022e32..0e5bc16 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -23,6 +23,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td" def NVVM_Dialect : Dialect { let name = "nvvm"; let cppNamespace = "NVVM"; + let dependentDialects = [ "LLVM::LLVMDialect" ]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h b/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h index bf761c3..eb40373 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h @@ -22,6 +22,7 @@ #ifndef MLIR_DIALECT_LLVMIR_ROCDLDIALECT_H_ #define MLIR_DIALECT_LLVMIR_ROCDLDIALECT_H_ +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 0cd1169..f9aca5a 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -23,6 +23,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td" def ROCDL_Dialect : Dialect { let name = "rocdl"; let cppNamespace = "ROCDL"; + let dependentDialects = [ "LLVM::LLVMDialect" ]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td index 11f12ad..dcf4b5e 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -30,17 +30,20 @@ def LinalgFusion : FunctionPass<"linalg-fusion"> { def LinalgFusionOfTensorOps : Pass<"linalg-fusion-for-tensor-ops"> { let summary = "Fuse operations on RankedTensorType in linalg dialect"; let constructor = "mlir::createLinalgFusionOfTensorOpsPass()"; + let dependentDialects = ["AffineDialect"]; } def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> { let summary = "Lower the operations from the linalg dialect into affine " "loops"; let constructor = "mlir::createConvertLinalgToAffineLoopsPass()"; + let dependentDialects = ["AffineDialect"]; } def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> { let summary = "Lower the operations from the linalg dialect into loops"; let constructor = "mlir::createConvertLinalgToLoopsPass()"; + let dependentDialects = ["scf::SCFDialect", "AffineDialect"]; } def LinalgOnTensorsToBuffers : Pass<"convert-linalg-on-tensors-to-buffers", "ModuleOp"> { @@ -54,6 +57,7 @@ def LinalgLowerToParallelLoops let summary = "Lower the operations from the linalg dialect into parallel " "loops"; let constructor = "mlir::createConvertLinalgToParallelLoopsPass()"; + let dependentDialects = ["AffineDialect", "scf::SCFDialect"]; } def LinalgPromotion : FunctionPass<"linalg-promote-subviews"> { @@ -70,6 +74,9 @@ def LinalgPromotion : FunctionPass<"linalg-promote-subviews"> { def LinalgTiling : FunctionPass<"linalg-tile"> { let summary = "Tile operations in the linalg dialect"; let constructor = "mlir::createLinalgTilingPass()"; + let dependentDialects = [ + "AffineDialect", "scf::SCFDialect" + ]; let options = [ ListOption<"tileSizes", "linalg-tile-sizes", "int64_t", "Test generation of dynamic promoted buffers", @@ -86,6 +93,7 @@ def LinalgTilingToParallelLoops "Test generation of dynamic promoted buffers", "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated"> ]; + let dependentDialects = ["AffineDialect", "scf::SCFDialect"]; } #endif // MLIR_DIALECT_LINALG_PASSES diff --git a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td index 483d0ba..6f3cf0e 100644 --- a/mlir/include/mlir/Dialect/SCF/Passes.td +++ b/mlir/include/mlir/Dialect/SCF/Passes.td @@ -36,6 +36,7 @@ def SCFParallelLoopTiling : FunctionPass<"parallel-loop-tiling"> { "Factors to tile parallel loops by", "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated"> ]; + let dependentDialects = ["AffineDialect"]; } #endif // MLIR_DIALECT_SCF_PASSES diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h index 4f9e4cb..d00d86db 100644 --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -16,6 +16,8 @@ #include "mlir/IR/OperationSupport.h" #include "mlir/Support/TypeID.h" +#include <map> + namespace mlir { class DialectAsmParser; class DialectAsmPrinter; @@ -23,7 +25,7 @@ class DialectInterface; class OpBuilder; class Type; -using DialectAllocatorFunction = std::function<void(MLIRContext *)>; +using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>; /// Dialects are groups of MLIR operations and behavior associated with the /// entire group. For example, hooks into other systems for constant folding, @@ -212,30 +214,81 @@ private: /// A collection of registered dialect interfaces. DenseMap<TypeID, std::unique_ptr<DialectInterface>> registeredInterfaces; - /// Registers a specific dialect creation function with the global registry. - /// Used through the registerDialect template. - /// Registrations are deduplicated by dialect TypeID and only the first - /// registration will be used. - static void - registerDialectAllocator(TypeID typeID, - const DialectAllocatorFunction &function); - template <typename ConcreteDialect> friend void registerDialect(); friend class MLIRContext; }; -/// Registers all dialects and hooks from the global registries with the -/// specified MLIRContext. +/// The DialectRegistry maps a dialect namespace to a constructor for the +/// matching dialect. +/// This allows for decoupling the list of dialects "available" from the +/// dialects loaded in the Context. The parser in particular will lazily load +/// dialects in in the Context as operations are encountered. +class DialectRegistry { + using MapTy = + std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>; + +public: + template <typename ConcreteDialect> void insert() { + insert(TypeID::get<ConcreteDialect>(), + ConcreteDialect::getDialectNamespace(), + static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) { + // Just allocate the dialect, the context + // takes ownership of it. + return ctx->getOrLoadDialect<ConcreteDialect>(); + }))); + } + + /// Add a new dialect constructor to the registry. + void insert(TypeID typeID, StringRef name, DialectAllocatorFunction ctor); + + /// Load a dialect for this namespace in the provided context. + Dialect *loadByName(StringRef name, MLIRContext *context); + + // Register all dialects available in the current registry with the registry + // in the provided context. + void appendTo(DialectRegistry &destination) { + for (const auto &name_and_registration_it : registry) + destination.insert(name_and_registration_it.second.first, + name_and_registration_it.first, + name_and_registration_it.second.second); + } + // Load all dialects available in the registry in the provided context. + void loadAll(MLIRContext *context) { + for (const auto &name_and_registration_it : registry) + name_and_registration_it.second.second(context); + } + + MapTy::const_iterator begin() const { return registry.begin(); } + MapTy::const_iterator end() const { return registry.end(); } + +private: + MapTy registry; +}; + +/// Deprecated: this provides a global registry for convenience, while we're +/// transitionning the registration mechanism to a stateless approach. +DialectRegistry &getGlobalDialectRegistry(); + +/// Registers all dialects from the global registries with the +/// specified MLIRContext. This won't load the dialects in the context, +/// but only make them available for lazy loading by name. /// Note: This method is not thread-safe. -void registerAllDialects(MLIRContext *context); +inline void registerAllDialects(MLIRContext *context) { + getGlobalDialectRegistry().appendTo(context->getDialectRegistry()); +} + +/// Register and return the dialect with the given namespace in the provided +/// context. Returns nullptr is there is no constructor registered for this +/// dialect. +inline Dialect *registerDialect(StringRef name, MLIRContext *context) { + return getGlobalDialectRegistry().loadByName(name, context); +} /// Utility to register a dialect. Client can register their dialect with the /// global registry by calling registerDialect<MyDialect>(); /// Note: This method is not thread-safe. template <typename ConcreteDialect> void registerDialect() { - Dialect::registerDialectAllocator( - TypeID::get<ConcreteDialect>(), - [](MLIRContext *ctx) { ctx->getOrCreateDialect<ConcreteDialect>(); }); + getGlobalDialectRegistry().insert<ConcreteDialect>(); } /// DialectRegistration provides a global initializer that registers a Dialect diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h index 7e281f3..3d467cd 100644 --- a/mlir/include/mlir/IR/FunctionSupport.h +++ b/mlir/include/mlir/IR/FunctionSupport.h @@ -428,7 +428,7 @@ LogicalResult FunctionLike<ConcreteType>::verifyTrait(Operation *op) { if (!attr.first.strref().contains('.')) return funcOp.emitOpError("arguments may only have dialect attributes"); auto dialectNamePair = attr.first.strref().split('.'); - if (auto *dialect = ctx->getRegisteredDialect(dialectNamePair.first)) { + if (auto *dialect = ctx->getLoadedDialect(dialectNamePair.first)) { if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0, /*argIndex=*/i, attr))) return failure(); @@ -444,7 +444,7 @@ LogicalResult FunctionLike<ConcreteType>::verifyTrait(Operation *op) { if (!attr.first.strref().contains('.')) return funcOp.emitOpError("results may only have dialect attributes"); auto dialectNamePair = attr.first.strref().split('.'); - if (auto *dialect = ctx->getRegisteredDialect(dialectNamePair.first)) { + if (auto *dialect = ctx->getLoadedDialect(dialectNamePair.first)) { if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0, /*resultIndex=*/i, attr))) diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h index 0192a8a..d406c30 100644 --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -19,10 +19,12 @@ namespace mlir { class AbstractOperation; class DiagnosticEngine; class Dialect; +class DialectRegistry; class InFlightDiagnostic; class Location; class MLIRContextImpl; class StorageUniquer; +DialectRegistry &getGlobalDialectRegistry(); /// MLIRContext is the top-level object for a collection of MLIR modules. It /// holds immortal uniqued objects like types, and the tables used to unique @@ -34,34 +36,54 @@ class StorageUniquer; /// class MLIRContext { public: - explicit MLIRContext(); + /// Create a new Context. + /// The loadAllDialects parameters allows to load all dialects from the global + /// registry on Context construction. It is deprecated and will be removed + /// soon. + explicit MLIRContext(bool loadAllDialects = true); ~MLIRContext(); - /// Return information about all registered IR dialects. - std::vector<Dialect *> getRegisteredDialects(); + /// Return information about all IR dialects loaded in the context. + std::vector<Dialect *> getLoadedDialects(); + + /// Return the dialect registry associated with this context. + DialectRegistry &getDialectRegistry(); + + /// Return information about all available dialects in the registry in this + /// context. + std::vector<StringRef> getAvailableDialects(); /// Get a registered IR dialect with the given namespace. If an exact match is /// not found, then return nullptr. - Dialect *getRegisteredDialect(StringRef name); + Dialect *getLoadedDialect(StringRef name); /// Get a registered IR dialect for the given derived dialect type. The /// derived type must provide a static 'getDialectNamespace' method. - template <typename T> T *getRegisteredDialect() { - return static_cast<T *>(getRegisteredDialect(T::getDialectNamespace())); + template <typename T> T *getLoadedDialect() { + return static_cast<T *>(getLoadedDialect(T::getDialectNamespace())); } /// Get (or create) a dialect for the given derived dialect type. The derived /// type must provide a static 'getDialectNamespace' method. - template <typename T> - T *getOrCreateDialect() { - return static_cast<T *>(getOrCreateDialect( - T::getDialectNamespace(), TypeID::get<T>(), [this]() { + template <typename T> T *getOrLoadDialect() { + return static_cast<T *>( + getOrLoadDialect(T::getDialectNamespace(), TypeID::get<T>(), [this]() { std::unique_ptr<T> dialect(new T(this)); - dialect->dialectID = TypeID::get<T>(); return dialect; })); } + /// Deprecated: load all globally registered dialects into this context. + /// This method will be removed soon, it can be used temporarily as we're + /// phasing out the global registry. + void loadAllGloballyRegisteredDialects(); + + /// Get (or create) a dialect for the given derived dialect name. + /// The dialect will be loaded from the registry if no dialect is found. + /// If no dialect is loaded for this name and none is available in the + /// registry, returns nullptr. + Dialect *getOrLoadDialect(StringRef name); + /// Return true if we allow to create operation for unregistered dialects. bool allowsUnregisteredDialects(); @@ -123,10 +145,12 @@ private: const std::unique_ptr<MLIRContextImpl> impl; /// Get a dialect for the provided namespace and TypeID: abort the program if - /// a dialect exist for this namespace with different TypeID. Returns a - /// pointer to the dialect owned by the context. - Dialect *getOrCreateDialect(StringRef dialectNamespace, TypeID dialectID, - function_ref<std::unique_ptr<Dialect>()> ctor); + /// a dialect exist for this namespace with different TypeID. If a dialect has + /// not been loaded for this namespace/TypeID yet, use the provided ctor to + /// create one on the fly and load it. Returns a pointer to the dialect owned + /// by the context. + Dialect *getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID, + function_ref<std::unique_ptr<Dialect>()> ctor); MLIRContext(const MLIRContext &) = delete; void operator=(const MLIRContext &) = delete; diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 9cc57a6..a28410f 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -244,6 +244,11 @@ class Dialect { // The description of the dialect. string description = ?; + // A list of dialects this dialect will load on construction as dependencies. + // These are dialects that this dialect may involved in canonicalization + // pattern or interfaces. + list<string> dependentDialects = []; + // The C++ namespace that ops of this dialect should be placed into. // // By default, uses the name of the dialect as the only namespace. To avoid diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h index b76b26f..a456616 100644 --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -35,29 +35,32 @@ namespace mlir { +// Add all the MLIR dialects to the provided registry. +inline void registerAllDialects(DialectRegistry ®istry) { + registry.insert<acc::OpenACCDialect>(); + registry.insert<AffineDialect>(); + registry.insert<avx512::AVX512Dialect>(); + registry.insert<gpu::GPUDialect>(); + registry.insert<LLVM::LLVMAVX512Dialect>(); + registry.insert<LLVM::LLVMDialect>(); + registry.insert<linalg::LinalgDialect>(); + registry.insert<scf::SCFDialect>(); + registry.insert<omp::OpenMPDialect>(); + registry.insert<quant::QuantizationDialect>(); + registry.insert<spirv::SPIRVDialect>(); + registry.insert<StandardOpsDialect>(); + registry.insert<vector::VectorDialect>(); + registry.insert<NVVM::NVVMDialect>(); + registry.insert<ROCDL::ROCDLDialect>(); + registry.insert<SDBMDialect>(); + registry.insert<shape::ShapeDialect>(); +} + // This function should be called before creating any MLIRContext if one expect // all the possible dialects to be made available to the context automatically. inline void registerAllDialects() { - static bool init_once = []() { - registerDialect<acc::OpenACCDialect>(); - registerDialect<AffineDialect>(); - registerDialect<avx512::AVX512Dialect>(); - registerDialect<gpu::GPUDialect>(); - registerDialect<LLVM::LLVMAVX512Dialect>(); - registerDialect<LLVM::LLVMDialect>(); - registerDialect<linalg::LinalgDialect>(); - registerDialect<scf::SCFDialect>(); - registerDialect<omp::OpenMPDialect>(); - registerDialect<quant::QuantizationDialect>(); - registerDialect<spirv::SPIRVDialect>(); - registerDialect<StandardOpsDialect>(); - registerDialect<vector::VectorDialect>(); - registerDialect<NVVM::NVVMDialect>(); - registerDialect<ROCDL::ROCDLDialect>(); - registerDialect<SDBMDialect>(); - registerDialect<shape::ShapeDialect>(); - return true; - }(); + static bool init_once = + ([]() { registerAllDialects(getGlobalDialectRegistry()); }(), true); (void)init_once; } } // namespace mlir diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h index 7c0f9bd..ea361ae 100644 --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -9,6 +9,7 @@ #ifndef MLIR_PASS_PASS_H #define MLIR_PASS_PASS_H +#include "mlir/IR/Dialect.h" #include "mlir/IR/Function.h" #include "mlir/Pass/AnalysisManager.h" #include "mlir/Pass/PassRegistry.h" @@ -57,6 +58,13 @@ public: /// Returns the derived pass name. virtual StringRef getName() const = 0; + /// Register dependent dialects for the current pass. + /// A pass is expected to register the dialects it will create operations for, + /// other than dialect that exists in the input. For example, a pass that + /// converts from Linalg to Affine would register the Affine dialect but does + /// not need to register Linalg. + virtual void getDependentDialects(DialectRegistry ®istry) const {} + /// Returns the command line argument used when registering this pass. Return /// an empty string if one does not exist. virtual StringRef getArgument() const { diff --git a/mlir/include/mlir/Pass/PassBase.td b/mlir/include/mlir/Pass/PassBase.td index 54b4403..749d042 100644 --- a/mlir/include/mlir/Pass/PassBase.td +++ b/mlir/include/mlir/Pass/PassBase.td @@ -78,6 +78,9 @@ class PassBase<string passArg, string base> { // A C++ constructor call to create an instance of this pass. code constructor = [{}]; + // A list of dialects this pass may produce operations in. + list<string> dependentDialects = []; + // A set of options provided by this pass. list<Option> options = []; diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h index 9cbfb0b..29e7c07 100644 --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -9,6 +9,7 @@ #ifndef MLIR_PASS_PASSMANAGER_H #define MLIR_PASS_PASSMANAGER_H +#include "mlir/IR/Dialect.h" #include "mlir/IR/OperationSupport.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/Optional.h" @@ -58,6 +59,14 @@ public: pass_iterator end(); iterator_range<pass_iterator> getPasses() { return {begin(), end()}; } + using const_pass_iterator = llvm::pointee_iterator< + std::vector<std::unique_ptr<Pass>>::const_iterator>; + const_pass_iterator begin() const; + const_pass_iterator end() const; + iterator_range<const_pass_iterator> getPasses() const { + return {begin(), end()}; + } + /// Run the held passes over the given operation. LogicalResult run(Operation *op, AnalysisManager am); @@ -100,6 +109,11 @@ public: /// Merge the pass statistics of this class into 'other'. void mergeStatisticsInto(OpPassManager &other); + /// Register dependent dialects for the current pass manager. + /// This is forwarding to every pass in this PassManager, see the + /// documentation for the same method on the Pass class. + void getDependentDialects(DialectRegistry &dialects) const; + private: OpPassManager(OperationName name, bool verifyPasses); diff --git a/mlir/include/mlir/Support/MlirOptMain.h b/mlir/include/mlir/Support/MlirOptMain.h index f235ea3..741276b 100644 --- a/mlir/include/mlir/Support/MlirOptMain.h +++ b/mlir/include/mlir/Support/MlirOptMain.h @@ -22,10 +22,15 @@ namespace mlir { struct LogicalResult; class PassPipelineCLParser; +/// Run an passPipeline on the provided memory buffer loaded as an MLIRModule. +/// The preloadDialectsInContext option will trigger an option upfront loading +/// of all dialects from the global registry in the MLIRContext. This option is +/// deprecated and will be removed soon. LogicalResult MlirOptMain(llvm::raw_ostream &os, std::unique_ptr<llvm::MemoryBuffer> buffer, const PassPipelineCLParser &passPipeline, bool splitInputFile, bool verifyDiagnostics, - bool verifyPasses, bool allowUnregisteredDialects); + bool verifyPasses, bool allowUnregisteredDialects, + bool preloadDialectsInContext = false); } // end namespace mlir diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h index 5e85806..99217d8 100644 --- a/mlir/include/mlir/TableGen/Dialect.h +++ b/mlir/include/mlir/TableGen/Dialect.h @@ -14,6 +14,7 @@ #include "mlir/Support/LLVM.h" #include <string> +#include <vector> namespace llvm { class Record; @@ -25,7 +26,7 @@ namespace tblgen { // and provides helper methods for accessing them. class Dialect { public: - explicit Dialect(const llvm::Record *def) : def(def) {} + explicit Dialect(const llvm::Record *def); // Returns the name of this dialect. StringRef getName() const; @@ -43,6 +44,10 @@ public: // Returns the description of the dialect. Returns empty string if none. StringRef getDescription() const; + // Returns the list of dialect (class names) that this dialect depends on. + // These are dialects that will be loaded on construction of this dialect. + ArrayRef<StringRef> getDependentDialects() const; + // Returns the dialects extra class declaration code. llvm::Optional<StringRef> getExtraClassDeclaration() const; @@ -70,6 +75,7 @@ public: private: const llvm::Record *def; + std::vector<StringRef> dependentDialects; }; } // end namespace tblgen } // end namespace mlir diff --git a/mlir/include/mlir/TableGen/Pass.h b/mlir/include/mlir/TableGen/Pass.h index 02427e4..968c854 100644 --- a/mlir/include/mlir/TableGen/Pass.h +++ b/mlir/include/mlir/TableGen/Pass.h @@ -94,6 +94,9 @@ public: /// Return the C++ constructor call to create an instance of this pass. StringRef getConstructor() const; + /// Return the dialects this pass needs to be registered. + ArrayRef<StringRef> getDependentDialects() const; + /// Return the options provided by this pass. ArrayRef<PassOption> getOptions() const; @@ -104,6 +107,7 @@ public: private: const llvm::Record *def; + std::vector<StringRef> dependentDialects; std::vector<PassOption> options; std::vector<PassStatistic> statistics; }; diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index 7787805..3292d5e 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -162,6 +162,8 @@ def BufferPlacement : FunctionPass<"buffer-placement"> { }]; let constructor = "mlir::createBufferPlacementPass()"; + // TODO: this pass likely shouldn't depend on Linalg? + let dependentDialects = ["linalg::LinalgDialect"]; } def Canonicalizer : Pass<"canonicalize"> { diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 4ccfb45..2417c1d 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -9,9 +9,11 @@ #include "mlir-c/IR.h" #include "mlir/IR/Attributes.h" +#include "mlir/IR/Dialect.h" #include "mlir/IR/Module.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Types.h" +#include "mlir/InitAllDialects.h" #include "mlir/Parser.h" #include "llvm/Support/raw_ostream.h" @@ -89,12 +91,17 @@ private: /* ========================================================================== */ MlirContext mlirContextCreate() { - auto *context = new MLIRContext; + auto *context = new MLIRContext(false); return wrap(context); } void mlirContextDestroy(MlirContext context) { delete unwrap(context); } +void mlirContextLoadAllDialects(MlirContext context) { + registerAllDialects(unwrap(context)); + getGlobalDialectRegistry().loadAll(unwrap(context)); +} + /* ========================================================================== */ /* Location API. */ /* ========================================================================== */ diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp index 1ebf481..4267393 100644 --- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp @@ -16,6 +16,7 @@ #include "../PassDetail.h" #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h" #include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Serialization.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp index 7b57854..0460d98 100644 --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h index 6da0bc8..7fa5a5a 100644 --- a/mlir/lib/Conversion/PassDetail.h +++ b/mlir/lib/Conversion/PassDetail.h @@ -12,11 +12,43 @@ #include "mlir/Pass/Pass.h" namespace mlir { +class AffineDialect; +class StandardOpsDialect; + +// Forward declaration from Dialect.h +template <typename ConcreteDialect> +void registerDialect(DialectRegistry ®istry); namespace gpu { +class GPUDialect; class GPUModuleOp; } // end namespace gpu +namespace LLVM { +class LLVMDialect; +class LLVMAVX512Dialect; +} // end namespace LLVM + +namespace NVVM { +class NVVMDialect; +} // end namespace NVVM + +namespace ROCDL { +class ROCDLDialect; +} // end namespace ROCDL + +namespace scf { +class SCFDialect; +} // end namespace scf + +namespace spirv { +class SPIRVDialect; +} // end namespace spirv + +namespace vector { +class VectorDialect; +} // end namespace vector + #define GEN_PASS_CLASSES #include "mlir/Conversion/Passes.h.inc" diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index efe4a3c..6096703 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -125,7 +125,7 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx) /// Create an LLVMTypeConverter using custom LowerToLLVMOptions. LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options) - : llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()), + : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()), options(options) { assert(llvmDialect && "LLVM IR dialect is not registered"); if (options.indexBitwidth == kDeriveIndexBitwidthFromDataLayout) diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp index 19643d2..a2e608d 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp @@ -14,6 +14,7 @@ #include "../PassDetail.h" #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h" #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h" +#include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/PatternMatch.h" diff --git a/mlir/lib/Dialect/Affine/Transforms/PassDetail.h b/mlir/lib/Dialect/Affine/Transforms/PassDetail.h index 3bae059..da8f7ac 100644 --- a/mlir/lib/Dialect/Affine/Transforms/PassDetail.h +++ b/mlir/lib/Dialect/Affine/Transforms/PassDetail.h @@ -12,6 +12,16 @@ #include "mlir/Pass/Pass.h" namespace mlir { +// Forward declaration from Dialect.h +template <typename ConcreteDialect> +void registerDialect(DialectRegistry ®istry); + +namespace linalg { +class LinalgDialect; +} // end namespace linalg +namespace vector { +class VectorDialect; +} // end namespace vector #define GEN_PASS_CLASSES #include "mlir/Dialect/Affine/Passes.h.inc" diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 009699b..bf18c6c 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1224,6 +1224,7 @@ template <typename NamedStructuredOpType> static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result) { SmallVector<OpAsmParser::OperandType, 8> operandsInfo; + result.getContext()->getOrLoadDialect<StandardOpsDialect>(); // Optional attributes may be added. if (parser.parseOperandList(operandsInfo) || diff --git a/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h b/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h index 7fa05ff..0415aeb 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h +++ b/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h @@ -9,9 +9,18 @@ #ifndef DIALECT_LINALG_TRANSFORMS_PASSDETAIL_H_ #define DIALECT_LINALG_TRANSFORMS_PASSDETAIL_H_ +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/IR/Dialect.h" #include "mlir/Pass/Pass.h" namespace mlir { +// Forward declaration from Dialect.h +template <typename ConcreteDialect> +void registerDialect(DialectRegistry ®istry); + +namespace scf { +class SCFDialect; +} // end namespace scf #define GEN_PASS_CLASSES #include "mlir/Dialect/Linalg/Passes.h.inc" diff --git a/mlir/lib/Dialect/SCF/Transforms/PassDetail.h b/mlir/lib/Dialect/SCF/Transforms/PassDetail.h index 95f8636..6fa7f22 100644 --- a/mlir/lib/Dialect/SCF/Transforms/PassDetail.h +++ b/mlir/lib/Dialect/SCF/Transforms/PassDetail.h @@ -12,6 +12,11 @@ #include "mlir/Pass/Pass.h" namespace mlir { +// Forward declaration from Dialect.h +template <typename ConcreteDialect> +void registerDialect(DialectRegistry ®istry); + +class AffineDialect; #define GEN_PASS_CLASSES #include "mlir/Dialect/SCF/Passes.h.inc" diff --git a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp index 435c7fe..a1971c3 100644 --- a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp +++ b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp @@ -517,7 +517,7 @@ Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) { SDBMDialect *dialect; } converter; - converter.dialect = affine.getContext()->getRegisteredDialect<SDBMDialect>(); + converter.dialect = affine.getContext()->getOrLoadDialect<SDBMDialect>(); if (auto result = converter.visit(affine)) return result; diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp index 7959183..2b18adb 100644 --- a/mlir/lib/ExecutionEngine/JitRunner.cpp +++ b/mlir/lib/ExecutionEngine/JitRunner.cpp @@ -259,7 +259,9 @@ int mlir::JitRunnerMain( } } - MLIRContext context; + MLIRContext context(/*loadAllDialects=*/false); + registerAllDialects(&context); + auto m = parseMLIRInput(options.inputFilename, &context); if (!m) { llvm::errs() << "could not parse the input IR\n"; diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp index 555bb2b..f2f0a63 100644 --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -27,21 +27,25 @@ DialectAsmParser::~DialectAsmParser() {} //===----------------------------------------------------------------------===// /// Registry for all dialect allocation functions. -static llvm::ManagedStatic<llvm::MapVector<TypeID, DialectAllocatorFunction>> - dialectRegistry; - -void Dialect::registerDialectAllocator( - TypeID typeID, const DialectAllocatorFunction &function) { - assert(function && - "Attempting to register an empty dialect initialize function"); - dialectRegistry->insert({typeID, function}); +static llvm::ManagedStatic<DialectRegistry> dialectRegistry; +DialectRegistry &mlir::getGlobalDialectRegistry() { return *dialectRegistry; } + +Dialect *DialectRegistry::loadByName(StringRef name, MLIRContext *context) { + auto it = registry.find(std::string(name)); + if (it == registry.end()) + return nullptr; + return it->second.second(context); } -/// Registers all dialects and hooks from the global registries with the -/// specified MLIRContext. -void mlir::registerAllDialects(MLIRContext *context) { - for (const auto &it : *dialectRegistry) - it.second(context); +void DialectRegistry::insert(TypeID typeID, StringRef name, + DialectAllocatorFunction ctor) { + auto inserted = + registry.insert(std::make_pair(name, std::make_pair(typeID, ctor))); + if (!inserted.second && inserted.first->second.first != typeID) { + llvm::report_fatal_error( + "Trying to register different dialects for the same namespace: " + + name); + } } //===----------------------------------------------------------------------===// @@ -119,7 +123,7 @@ DialectInterface::~DialectInterface() {} DialectInterfaceCollectionBase::DialectInterfaceCollectionBase( MLIRContext *ctx, TypeID interfaceKind) { - for (auto *dialect : ctx->getRegisteredDialects()) { + for (auto *dialect : ctx->getLoadedDialects()) { if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) { interfaces.insert(interface); orderedInterfaces.push_back(interface); diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 0d66070..ed8e3db 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -31,10 +31,13 @@ #include "llvm/ADT/Twine.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/RWMutex.h" #include "llvm/Support/raw_ostream.h" #include <memory> +#define DEBUG_TYPE "mlircontext" + using namespace mlir; using namespace mlir::detail; @@ -275,7 +278,8 @@ public: /// This is a list of dialects that are created referring to this context. /// The MLIRContext owns the objects. - std::vector<std::unique_ptr<Dialect>> dialects; + DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects; + DialectRegistry dialectsRegistry; /// This is a mapping from operation name to AbstractOperation for registered /// operations. @@ -346,7 +350,7 @@ public: }; } // end namespace mlir -MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) { +MLIRContext::MLIRContext(bool loadAllDialects) : impl(new MLIRContextImpl()) { // Initialize values based on the command line flags if they were provided. if (clOptions.isConstructed()) { disableMultithreading(clOptions->disableThreading); @@ -355,8 +359,9 @@ MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) { } // Register dialects with this context. - getOrCreateDialect<BuiltinDialect>(); - registerAllDialects(this); + getOrLoadDialect<BuiltinDialect>(); + if (loadAllDialects) + loadAllGloballyRegisteredDialects(); // Initialize several common attributes and types to avoid the need to lock // the context when accessing them. @@ -438,54 +443,72 @@ DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; } // Dialect and Operation Registration //===----------------------------------------------------------------------===// +DialectRegistry &MLIRContext::getDialectRegistry() { + return impl->dialectsRegistry; +} + /// Return information about all registered IR dialects. -std::vector<Dialect *> MLIRContext::getRegisteredDialects() { +std::vector<Dialect *> MLIRContext::getLoadedDialects() { std::vector<Dialect *> result; - result.reserve(impl->dialects.size()); - for (auto &dialect : impl->dialects) - result.push_back(dialect.get()); + result.reserve(impl->loadedDialects.size()); + for (auto &dialect : impl->loadedDialects) { + result.push_back(dialect.second.get()); + } + llvm::sort(result, [](Dialect *lhs, Dialect *rhs) { + return lhs->getNamespace() < rhs->getNamespace(); + }); + return result; +} +std::vector<StringRef> MLIRContext::getAvailableDialects() { + std::vector<StringRef> result; + for (auto &dialect : impl->dialectsRegistry) + result.push_back(dialect.first); return result; } /// Get a registered IR dialect with the given namespace. If none is found, /// then return nullptr. -Dialect *MLIRContext::getRegisteredDialect(StringRef name) { +Dialect *MLIRContext::getLoadedDialect(StringRef name) { // Dialects are sorted by name, so we can use binary search for lookup. - auto it = llvm::lower_bound( - impl->dialects, name, - [](const auto &lhs, StringRef rhs) { return lhs->getNamespace() < rhs; }); - return (it != impl->dialects.end() && (*it)->getNamespace() == name) - ? (*it).get() - : nullptr; + auto it = impl->loadedDialects.find(name); + return (it != impl->loadedDialects.end()) ? it->second.get() : nullptr; +} + +Dialect *MLIRContext::getOrLoadDialect(StringRef name) { + Dialect *dialect = getLoadedDialect(name); + if (dialect) + return dialect; + return impl->dialectsRegistry.loadByName(name, this); } /// Get a dialect for the provided namespace and TypeID: abort the program if a /// dialect exist for this namespace with different TypeID. Returns a pointer to /// the dialect owned by the context. Dialect * -MLIRContext::getOrCreateDialect(StringRef dialectNamespace, TypeID dialectID, - function_ref<std::unique_ptr<Dialect>()> ctor) { +MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID, + function_ref<std::unique_ptr<Dialect>()> ctor) { auto &impl = getImpl(); // Get the correct insertion position sorted by namespace. - auto insertPt = - llvm::lower_bound(impl.dialects, nullptr, - [&](const std::unique_ptr<Dialect> &lhs, - const std::unique_ptr<Dialect> &rhs) { - if (!lhs) - return dialectNamespace < rhs->getNamespace(); - return lhs->getNamespace() < dialectNamespace; - }); + std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace]; + + if (!dialect) { + LLVM_DEBUG(llvm::dbgs() + << "Load new dialect in Context" << dialectNamespace); + dialect = ctor(); + assert(dialect && "dialect ctor failed"); + return dialect.get(); + } // Abort if dialect with namespace has already been registered. - if (insertPt != impl.dialects.end() && - (*insertPt)->getNamespace() == dialectNamespace) { - if ((*insertPt)->getTypeID() == dialectID) - return insertPt->get(); + if (dialect->getTypeID() != dialectID) llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace + "' has already been registered"); - } - auto it = impl.dialects.insert(insertPt, ctor()); - return &**it; + + return dialect.get(); +} + +void MLIRContext::loadAllGloballyRegisteredDialects() { + getGlobalDialectRegistry().loadAll(this); } bool MLIRContext::allowsUnregisteredDialects() { diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 152ed01..dce570a 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -214,7 +214,7 @@ Dialect *Operation::getDialect() { // If this operation hasn't been registered or doesn't have abstract // operation, try looking up the dialect name in the context. - return getContext()->getRegisteredDialect(getName().getDialect()); + return getContext()->getLoadedDialect(getName().getDialect()); } Region *Operation::getParentRegion() { diff --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp index b1aed88..4caf989 100644 --- a/mlir/lib/IR/Verifier.cpp +++ b/mlir/lib/IR/Verifier.cpp @@ -50,7 +50,7 @@ public: Dialect *getDialectForAttribute(const NamedAttribute &attr) { assert(attr.first.strref().contains('.') && "expected dialect attribute"); auto dialectNamePair = attr.first.strref().split('.'); - return ctx->getRegisteredDialect(dialectNamePair.first); + return ctx->getLoadedDialect(dialectNamePair.first); } private: @@ -218,7 +218,7 @@ LogicalResult OperationVerifier::verifyOperation(Operation &op) { auto it = dialectAllowsUnknownOps.find(dialectPrefix); if (it == dialectAllowsUnknownOps.end()) { // If the operation dialect is registered, query it directly. - if (auto *dialect = ctx->getRegisteredDialect(dialectPrefix)) + if (auto *dialect = ctx->getLoadedDialect(dialectPrefix)) it = dialectAllowsUnknownOps .try_emplace(dialectPrefix, dialect->allowsUnknownOperations()) .first; diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp index 1c1261e..37ee938 100644 --- a/mlir/lib/Parser/AttributeParser.cpp +++ b/mlir/lib/Parser/AttributeParser.cpp @@ -12,6 +12,7 @@ #include "Parser.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/Dialect.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/StandardTypes.h" #include "llvm/ADT/StringExtras.h" @@ -246,6 +247,11 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) { return emitError("duplicate key in dictionary attribute"); consumeToken(); + // Lazy load a dialect in the context if there is a possible namespace. + auto splitName = nameId->strref().split('.'); + if (!splitName.second.empty()) + getContext()->getOrLoadDialect(splitName.first); + // Try to parse the '=' for the attribute value. if (!consumeIf(Token::equal)) { // If there is no '=', we treat this as a unit attribute. @@ -817,7 +823,9 @@ Attribute Parser::parseOpaqueElementsAttr(Type attrType) { return (emitError("expected dialect namespace"), nullptr); auto name = getToken().getStringValue(); - auto *dialect = builder.getContext()->getRegisteredDialect(name); + // Lazy load a dialect in the context if there is a possible namespace. + Dialect *dialect = builder.getContext()->getOrLoadDialect(name); + // TODO: Allow for having an unknown dialect on an opaque // attribute. Otherwise, it can't be roundtripped without having the dialect // registered. diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp index 3b522a8..d45ddf0 100644 --- a/mlir/lib/Parser/DialectSymbolParser.cpp +++ b/mlir/lib/Parser/DialectSymbolParser.cpp @@ -526,7 +526,8 @@ Attribute Parser::parseExtendedAttr(Type type) { return Attribute(); // If we found a registered dialect, then ask it to parse the attribute. - if (auto *dialect = state.context->getRegisteredDialect(dialectName)) { + if (Dialect *dialect = + builder.getContext()->getOrLoadDialect(dialectName)) { return parseSymbol<Attribute>( symbolData, state.context, state.symbols, [&](Parser &parser) { CustomDialectAsmParser customParser(symbolData, parser); @@ -563,7 +564,9 @@ Type Parser::parseExtendedType() { [&](StringRef dialectName, StringRef symbolData, llvm::SMLoc loc) -> Type { // If we found a registered dialect, then ask it to parse the type. - if (auto *dialect = state.context->getRegisteredDialect(dialectName)) { + auto *dialect = state.context->getOrLoadDialect(dialectName); + + if (dialect) { return parseSymbol<Type>( symbolData, state.context, state.symbols, [&](Parser &parser) { CustomDialectAsmParser customParser(symbolData, parser); diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 3a995a4..837b08c 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -12,6 +12,7 @@ #include "Parser.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/Dialect.h" #include "mlir/IR/Module.h" #include "mlir/IR/Verifier.h" #include "mlir/Parser.h" @@ -727,7 +728,7 @@ Operation *OperationParser::parseGenericOperation() { // Get location information for the operation. auto srcLocation = getEncodedSourceLocation(getToken().getLoc()); - auto name = getToken().getStringValue(); + std::string name = getToken().getStringValue(); if (name.empty()) return (emitError("empty operation name is invalid"), nullptr); if (name.find('\0') != StringRef::npos) @@ -737,6 +738,15 @@ Operation *OperationParser::parseGenericOperation() { OperationState result(srcLocation, name); + // Lazy load dialects in the context as needed. + if (!result.name.getAbstractOperation()) { + StringRef dialectName = StringRef(name).split('.').first; + if (!getContext()->getLoadedDialect(dialectName) && + getContext()->getOrLoadDialect(dialectName)) { + result.name = OperationName(name, getContext()); + } + } + // Parse the operand list. SmallVector<SSAUseInfo, 8> operandInfos; if (parseToken(Token::l_paren, "expected '(' to start operand list") || @@ -1442,17 +1452,28 @@ private: Operation * OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) { - auto opLoc = getToken().getLoc(); - auto opName = getTokenSpelling(); + llvm::SMLoc opLoc = getToken().getLoc(); + StringRef opName = getTokenSpelling(); auto *opDefinition = AbstractOperation::lookup(opName, getContext()); - if (!opDefinition && !opName.contains('.')) { - // If the operation name has no namespace prefix we treat it as a standard - // operation and prefix it with "std". - // TODO: Would it be better to just build a mapping of the registered - // operations in the standard dialect? - opDefinition = - AbstractOperation::lookup(Twine("std." + opName).str(), getContext()); + if (!opDefinition) { + if (opName.contains('.')) { + // This op has a dialect, we try to check if we can register it in the + // context on the fly. + StringRef dialectName = opName.split('.').first; + if (!getContext()->getLoadedDialect(dialectName) && + getContext()->getOrLoadDialect(dialectName)) { + opDefinition = AbstractOperation::lookup(opName, getContext()); + } + } else { + // If the operation name has no namespace prefix we treat it as a standard + // operation and prefix it with "std". + // TODO: Would it be better to just build a mapping of the registered + // operations in the standard dialect? + if (getContext()->getOrLoadDialect("std")) + opDefinition = AbstractOperation::lookup(Twine("std." + opName).str(), + getContext()); + } } if (!opDefinition) { diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index b791bf4..9debd11 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -290,6 +290,13 @@ OpPassManager::pass_iterator OpPassManager::begin() { } OpPassManager::pass_iterator OpPassManager::end() { return impl->passes.end(); } +OpPassManager::const_pass_iterator OpPassManager::begin() const { + return impl->passes.begin(); +} +OpPassManager::const_pass_iterator OpPassManager::end() const { + return impl->passes.end(); +} + /// Run all of the passes in this manager over the current operation. LogicalResult OpPassManager::run(Operation *op, AnalysisManager am) { // Run each of the held passes. @@ -346,6 +353,16 @@ void OpPassManager::printAsTextualPipeline(raw_ostream &os) { ::printAsTextualPipeline(impl->passes, os); } +static void registerDialectsForPipeline(const OpPassManager &pm, + DialectRegistry &dialects) { + for (const Pass &pass : pm.getPasses()) + pass.getDependentDialects(dialects); +} + +void OpPassManager::getDependentDialects(DialectRegistry &dialects) const { + registerDialectsForPipeline(*this, dialects); +} + //===----------------------------------------------------------------------===// // OpToOpPassAdaptor //===----------------------------------------------------------------------===// @@ -378,6 +395,11 @@ OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr) { mgrs.emplace_back(std::move(mgr)); } +void OpToOpPassAdaptor::getDependentDialects(DialectRegistry &dialects) const { + for (auto &pm : mgrs) + pm.getDependentDialects(dialects); +} + /// Merge the current pass adaptor into given 'rhs'. void OpToOpPassAdaptor::mergeInto(OpToOpPassAdaptor &rhs) { for (auto &pm : mgrs) { @@ -721,6 +743,11 @@ LogicalResult PassManager::run(ModuleOp module) { // pipeline. getImpl().coalesceAdjacentAdaptorPasses(); + // Register all dialects for the current pipeline. + DialectRegistry dependent_dialects; + getDependentDialects(dependent_dialects); + dependent_dialects.loadAll(module.getContext()); + // Construct an analysis manager for the pipeline. ModuleAnalysisManager am(module, instrumentor.get()); diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h index 2342a1a..f69701d 100644 --- a/mlir/lib/Pass/PassDetail.h +++ b/mlir/lib/Pass/PassDetail.h @@ -43,6 +43,10 @@ public: /// Returns the pass managers held by this adaptor. MutableArrayRef<OpPassManager> getPassManagers() { return mgrs; } + /// Populate the set of dependent dialects for the passes in the current + /// adaptor. + void getDependentDialects(DialectRegistry &dialects) const override; + /// Return the async pass managers held by this parallel adaptor. MutableArrayRef<SmallVector<OpPassManager, 1>> getParallelPassManagers() { return asyncExecutors; diff --git a/mlir/lib/Support/MlirOptMain.cpp b/mlir/lib/Support/MlirOptMain.cpp index 25e1970..e450fb3 100644 --- a/mlir/lib/Support/MlirOptMain.cpp +++ b/mlir/lib/Support/MlirOptMain.cpp @@ -75,13 +75,17 @@ static LogicalResult processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer, bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects, + bool preloadDialectsInContext, const PassPipelineCLParser &passPipeline) { // Tell sourceMgr about this buffer, which is what the parser will pick up. SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc()); // Parse the input file. - MLIRContext context; + MLIRContext context(/*loadAllDialects=*/false); + registerAllDialects(&context); + if (preloadDialectsInContext) + context.getDialectRegistry().loadAll(&context); context.allowUnregisteredDialects(allowUnregisteredDialects); context.printOpOnDiagnostic(!verifyDiagnostics); @@ -111,7 +115,8 @@ LogicalResult mlir::MlirOptMain(raw_ostream &os, const PassPipelineCLParser &passPipeline, bool splitInputFile, bool verifyDiagnostics, bool verifyPasses, - bool allowUnregisteredDialects) { + bool allowUnregisteredDialects, + bool preloadDialectsInContext) { // The split-input-file mode is a very specific mode that slices the file // up into small pieces and checks each independently. if (splitInputFile) @@ -120,10 +125,11 @@ LogicalResult mlir::MlirOptMain(raw_ostream &os, [&](std::unique_ptr<MemoryBuffer> chunkBuffer, raw_ostream &os) { return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics, verifyPasses, allowUnregisteredDialects, - passPipeline); + preloadDialectsInContext, passPipeline); }, os); return processBuffer(os, std::move(buffer), verifyDiagnostics, verifyPasses, - allowUnregisteredDialects, passPipeline); + allowUnregisteredDialects, preloadDialectsInContext, + passPipeline); } diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp index 6af77e7..8aee067 100644 --- a/mlir/lib/TableGen/Dialect.cpp +++ b/mlir/lib/TableGen/Dialect.cpp @@ -16,6 +16,11 @@ using namespace mlir; using namespace mlir::tblgen; +Dialect::Dialect(const llvm::Record *def) : def(def) { + for (StringRef dialect : def->getValueAsListOfStrings("dependentDialects")) + dependentDialects.push_back(dialect); +} + StringRef Dialect::getName() const { return def->getValueAsString("name"); } StringRef Dialect::getCppNamespace() const { @@ -46,6 +51,10 @@ StringRef Dialect::getDescription() const { return getAsStringOrEmpty(*def, "description"); } +ArrayRef<StringRef> Dialect::getDependentDialects() const { + return dependentDialects; +} + llvm::Optional<StringRef> Dialect::getExtraClassDeclaration() const { auto value = def->getValueAsString("extraClassDeclaration"); return value.empty() ? llvm::Optional<StringRef>() : value; diff --git a/mlir/lib/TableGen/Pass.cpp b/mlir/lib/TableGen/Pass.cpp index 4bc46b6..f961806 100644 --- a/mlir/lib/TableGen/Pass.cpp +++ b/mlir/lib/TableGen/Pass.cpp @@ -69,6 +69,8 @@ Pass::Pass(const llvm::Record *def) : def(def) { options.push_back(PassOption(init)); for (auto *init : def->getValueAsListOfDefs("statistics")) statistics.push_back(PassStatistic(init)); + for (StringRef dialect : def->getValueAsListOfStrings("dependentDialects")) + dependentDialects.push_back(dialect); } StringRef Pass::getArgument() const { @@ -88,6 +90,9 @@ StringRef Pass::getDescription() const { StringRef Pass::getConstructor() const { return def->getValueAsString("constructor"); } +ArrayRef<StringRef> Pass::getDependentDialects() const { + return dependentDialects; +} ArrayRef<PassOption> Pass::getOptions() const { return options; } diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp index 470044b..1d01569 100644 --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -836,6 +836,7 @@ LogicalResult Importer::processBasicBlock(llvm::BasicBlock *bb, Block *block) { OwningModuleRef mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule, MLIRContext *context) { + context->getOrLoadDialect<LLVMDialect>(); OwningModuleRef module(ModuleOp::create( FileLineColLoc::get("", /*line=*/0, /*column=*/0, context))); diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 215c191..027422b 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -302,8 +302,7 @@ ModuleTranslation::ModuleTranslation(Operation *module, : mlirModule(module), llvmModule(std::move(llvmModule)), debugTranslation( std::make_unique<DebugTranslation>(module, *this->llvmModule)), - ompDialect( - module->getContext()->getRegisteredDialect<omp::OpenMPDialect>()), + ompDialect(module->getContext()->getOrLoadDialect<omp::OpenMPDialect>()), typeTranslator(this->llvmModule->getContext()) { assert(satisfiesLLVMModule(mlirModule) && "mlirModule should honor LLVM's module semantics."); @@ -944,7 +943,7 @@ ModuleTranslation::lookupValues(ValueRange values) { std::unique_ptr<llvm::Module> ModuleTranslation::prepareLLVMModule( Operation *m, llvm::LLVMContext &llvmContext, StringRef name) { - auto *dialect = m->getContext()->getRegisteredDialect<LLVM::LLVMDialect>(); + auto *dialect = m->getContext()->getOrLoadDialect<LLVM::LLVMDialect>(); assert(dialect && "LLVM dialect must be registered"); auto llvmModule = std::make_unique<llvm::Module>(name, llvmContext); diff --git a/mlir/lib/Transforms/PassDetail.h b/mlir/lib/Transforms/PassDetail.h index c6f7e22..220ed1a 100644 --- a/mlir/lib/Transforms/PassDetail.h +++ b/mlir/lib/Transforms/PassDetail.h @@ -12,6 +12,13 @@ #include "mlir/Pass/Pass.h" namespace mlir { +// Forward declaration from Dialect.h +template <typename ConcreteDialect> +void registerDialect(DialectRegistry ®istry); + +namespace linalg { +class LinalgDialect; +} // end namespace linalg #define GEN_PASS_CLASSES #include "mlir/Transforms/Passes.h.inc" diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c index d6ab351..df2d32f 100644 --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -243,6 +243,7 @@ static void printFirstOfEach(MlirOperation operation) { int main() { mlirRegisterAllDialects(); MlirContext ctx = mlirContextCreate(); + mlirContextLoadAllDialects(ctx); MlirLocation location = mlirLocationUnknownGet(ctx); MlirModule moduleOp = makeAdd(ctx, location); diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp index 3fcfcf2..e5766f0 100644 --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -36,16 +36,16 @@ using namespace mlir::edsc; using namespace mlir::edsc::intrinsics; static MLIRContext &globalContext() { - static bool init_once = []() { - registerDialect<AffineDialect>(); - registerDialect<linalg::LinalgDialect>(); - registerDialect<scf::SCFDialect>(); - registerDialect<StandardOpsDialect>(); - registerDialect<vector::VectorDialect>(); + static thread_local MLIRContext context(/*loadAllDialects=*/false); + static thread_local bool init_once = [&]() { + context.getOrLoadDialect<AffineDialect>(); + context.getOrLoadDialect<scf::SCFDialect>(); + context.getOrLoadDialect<linalg::LinalgDialect>(); + context.getOrLoadDialect<StandardOpsDialect>(); + context.getOrLoadDialect<vector::VectorDialect>(); return true; }(); (void)init_once; - static thread_local MLIRContext context; context.allowUnregisteredDialects(); return context; } diff --git a/mlir/test/SDBM/sdbm-api-test.cpp b/mlir/test/SDBM/sdbm-api-test.cpp index 0b58e29..ddefc52 100644 --- a/mlir/test/SDBM/sdbm-api-test.cpp +++ b/mlir/test/SDBM/sdbm-api-test.cpp @@ -19,18 +19,19 @@ using namespace mlir; -// Load the SDBM dialect -static DialectRegistration<SDBMDialect> SDBMRegistration; static MLIRContext *ctx() { - static thread_local MLIRContext context; + static thread_local MLIRContext context(/*loadAllDialects=*/false); + static thread_local bool once = + (context.getOrLoadDialect<SDBMDialect>(), true); + (void)once; return &context; } static SDBMDialect *dialect() { static thread_local SDBMDialect *d = nullptr; if (!d) { - d = ctx()->getRegisteredDialect<SDBMDialect>(); + d = ctx()->getOrLoadDialect<SDBMDialect>(); } return d; } diff --git a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp index a6719b0..cfac2dc 100644 --- a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp +++ b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp @@ -14,6 +14,7 @@ #include "mlir/Analysis/NestedMatcher.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Dialect/Vector/VectorUtils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Diagnostics.h" @@ -72,6 +73,9 @@ struct VectorizerTestPass : public PassWrapper<VectorizerTestPass, FunctionPass> { static constexpr auto kTestAffineMapOpName = "test_affine_map"; static constexpr auto kTestAffineMapAttrName = "affine_map"; + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<vector::VectorDialect>(); + } void runOnFunction() override; void testVectorShapeRatio(llvm::raw_ostream &outs); diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp index 0c1069f..03c425d 100644 --- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp +++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp @@ -30,7 +30,7 @@ void PrintOpAvailability::runOnFunction() { auto f = getFunction(); llvm::outs() << f.getName() << "\n"; - Dialect *spvDialect = getContext().getRegisteredDialect("spv"); + Dialect *spvDialect = getContext().getLoadedDialect("spv"); f.getOperation()->walk([&](Operation *op) { if (op->getDialect() != spvDialect) diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index f2a17a9..be5d799 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -768,6 +768,10 @@ struct TestTypeConversionProducer struct TestTypeConversionDriver : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<TestDialect>(); + } + void runOnOperation() override { // Initialize the type converter. TypeConverter converter; diff --git a/mlir/test/lib/Transforms/TestAllReduceLowering.cpp b/mlir/test/lib/Transforms/TestAllReduceLowering.cpp index c043d0f..0c72b6c 100644 --- a/mlir/test/lib/Transforms/TestAllReduceLowering.cpp +++ b/mlir/test/lib/Transforms/TestAllReduceLowering.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/GPU/Passes.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" @@ -19,6 +20,9 @@ using namespace mlir; namespace { struct TestAllReduceLoweringPass : public PassWrapper<TestAllReduceLoweringPass, OperationPass<ModuleOp>> { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<StandardOpsDialect>(); + } void runOnOperation() override { OwningRewritePatternList patterns; populateGpuRewritePatterns(&getContext(), patterns); diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp index 5ad441aa..6cc0924 100644 --- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp +++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp @@ -116,6 +116,10 @@ struct TestBufferPlacementPreparationPass patterns->insert<GenericOpConverter>(context, placer, converter); } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<linalg::LinalgDialect>(); + } + void runOnOperation() override { MLIRContext &context = this->getContext(); ConversionTarget target(context); diff --git a/mlir/test/lib/Transforms/TestGpuMemoryPromotion.cpp b/mlir/test/lib/Transforms/TestGpuMemoryPromotion.cpp index 08862dd..339ec1a 100644 --- a/mlir/test/lib/Transforms/TestGpuMemoryPromotion.cpp +++ b/mlir/test/lib/Transforms/TestGpuMemoryPromotion.cpp @@ -13,6 +13,9 @@ #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/MemoryPromotion.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/SPIRV/SPIRVDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/Pass/Pass.h" @@ -26,6 +29,11 @@ namespace { class TestGpuMemoryPromotionPass : public PassWrapper<TestGpuMemoryPromotionPass, OperationPass<gpu::GPUFuncOp>> { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<StandardOpsDialect>(); + registry.insert<scf::SCFDialect>(); + } + void runOnOperation() override { gpu::GPUFuncOp op = getOperation(); for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) { diff --git a/mlir/test/lib/Transforms/TestLinalgHoisting.cpp b/mlir/test/lib/Transforms/TestLinalgHoisting.cpp index d1e478f..5d4031f 100644 --- a/mlir/test/lib/Transforms/TestLinalgHoisting.cpp +++ b/mlir/test/lib/Transforms/TestLinalgHoisting.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" #include "mlir/Pass/Pass.h" @@ -22,6 +23,9 @@ struct TestLinalgHoisting : public PassWrapper<TestLinalgHoisting, FunctionPass> { TestLinalgHoisting() = default; TestLinalgHoisting(const TestLinalgHoisting &pass) {} + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<AffineDialect>(); + } void runOnFunction() override; diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp index f6c1160..b6f3cb4 100644 --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" @@ -30,6 +31,14 @@ struct TestLinalgTransforms TestLinalgTransforms() = default; TestLinalgTransforms(const TestLinalgTransforms &pass) {} + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<AffineDialect>(); + registry.insert<scf::SCFDialect>(); + registry.insert<StandardOpsDialect>(); + registry.insert<vector::VectorDialect>(); + registry.insert<gpu::GPUDialect>(); + } + void runOnFunction() override; Option<bool> testPatterns{*this, "test-patterns", diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp index 9da3156..24e7a8c 100644 --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -8,6 +8,9 @@ #include <type_traits> +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Dialect/Vector/VectorTransforms.h" @@ -128,6 +131,13 @@ struct TestVectorTransferFullPartialSplitPatterns TestVectorTransferFullPartialSplitPatterns() = default; TestVectorTransferFullPartialSplitPatterns( const TestVectorTransferFullPartialSplitPatterns &pass) {} + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<AffineDialect>(); + registry.insert<linalg::LinalgDialect>(); + registry.insert<scf::SCFDialect>(); + } + Option<bool> useLinalgOps{ *this, "use-linalg-copy", llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + " diff --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir index f99a68d..4cf6ea9 100644 --- a/mlir/test/mlir-opt/commandline.mlir +++ b/mlir/test/mlir-opt/commandline.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt --show-dialects | FileCheck %s -// CHECK: Registered Dialects: +// CHECK: Available Dialects: // CHECK: affine // CHECK: gpu // CHECK: linalg diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp index 12e6aee..92efef6 100644 --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -1703,7 +1703,7 @@ int main(int argc, char **argv) { if (testEmitIncludeTdHeader) output->os() << "include \"mlir/Dialect/Linalg/IR/LinalgStructuredOps.td\""; - MLIRContext context; + MLIRContext context(/*loadAllDialects=*/false); llvm::SourceMgr mgr; mgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc()); Parser parser(mgr, &context); diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index efcb328..53ea4da 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -175,11 +175,10 @@ int main(int argc, char **argv) { cl::ParseCommandLineOptions(argc, argv, "MLIR modular optimizer driver\n"); if(showDialects) { - llvm::outs() << "Registered Dialects:\n"; - MLIRContext context; - for(Dialect *dialect : context.getRegisteredDialects()) { - llvm::outs() << dialect->getNamespace() << "\n"; - } + MLIRContext context(false); + registerAllDialects(&context); + llvm::outs() << "Available Dialects:\n"; + interleave(context.getAvailableDialects(), llvm::outs(), "\n"); return 0; } diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp index 13421c4..797ecd7 100644 --- a/mlir/tools/mlir-tblgen/DialectGen.cpp +++ b/mlir/tools/mlir-tblgen/DialectGen.cpp @@ -61,11 +61,14 @@ filterForDialect(ArrayRef<llvm::Record *> records, Dialect &dialect) { /// /// {0}: The name of the dialect class. /// {1}: The dialect namespace. +/// {2}: initialization code that is emitted in the ctor body before calling +/// initialize() static const char *const dialectDeclBeginStr = R"( class {0} : public ::mlir::Dialect { explicit {0}(::mlir::MLIRContext *context) : ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>()) {{ + {2} initialize(); } void initialize(); @@ -74,6 +77,12 @@ public: static ::llvm::StringRef getDialectNamespace() { return "{1}"; } )"; +/// Registration for a single dependent dialect: to be inserted in the ctor +/// above for each dependent dialect. +const char *const dialectRegistrationTemplate = R"( + getContext()->getOrLoadDialect<{0}>(); +)"; + /// The code block for the attribute parser/printer hooks. static const char *const attrParserDecl = R"( /// Parse an attribute registered to this dialect. @@ -136,9 +145,18 @@ static void emitDialectDecl(Dialect &dialect, iterator_range<DialectFilterIterator> dialectAttrs, iterator_range<DialectFilterIterator> dialectTypes, raw_ostream &os) { + /// Build the list of dependent dialects + std::string dependentDialectRegistrations; + { + llvm::raw_string_ostream dialects_os(dependentDialectRegistrations); + for (StringRef dependentDialect : dialect.getDependentDialects()) + dialects_os << llvm::formatv(dialectRegistrationTemplate, + dependentDialect); + } // Emit the start of the decl. std::string cppName = dialect.getCppClassName(); - os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName()); + os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(), + dependentDialectRegistrations); // Check for any attributes/types registered to this dialect. If there are, // add the hooks for parsing/printing. diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp index c2dcdb8..b6f5f3f 100644 --- a/mlir/tools/mlir-tblgen/PassGen.cpp +++ b/mlir/tools/mlir-tblgen/PassGen.cpp @@ -36,6 +36,7 @@ static llvm::cl::opt<std::string> /// {0}: The def name of the pass record. /// {1}: The base class for the pass. /// {2): The command line argument for the pass. +/// {3}: The dependent dialects registration. const char *const passDeclBegin = R"( //===----------------------------------------------------------------------===// // {0} @@ -63,9 +64,20 @@ public: return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); } + /// Return the dialect that must be loaded in the context before this pass. + void getDependentDialects(::mlir::DialectRegistry ®istry) const override { + {3} + } + protected: )"; +/// Registration for a single dependent dialect, to be inserted for each +/// dependent dialect in the `getDependentDialects` above. +const char *const dialectRegistrationTemplate = R"( + registry.insert<{0}>(); +)"; + /// Emit the declarations for each of the pass options. static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) { for (const PassOption &opt : pass.getOptions()) { @@ -94,8 +106,15 @@ static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) { static void emitPassDecl(const Pass &pass, raw_ostream &os) { StringRef defName = pass.getDef()->getName(); + std::string dependentDialectRegistrations; + { + llvm::raw_string_ostream dialects_os(dependentDialectRegistrations); + for (StringRef dependentDialect : pass.getDependentDialects()) + dialects_os << llvm::formatv(dialectRegistrationTemplate, + dependentDialect); + } os << llvm::formatv(passDeclBegin, defName, pass.getBaseClass(), - pass.getArgument()); + pass.getArgument(), dependentDialectRegistrations); emitPassOptionDecls(pass, os); emitPassStatisticDecls(pass, os); os << "};\n"; diff --git a/mlir/tools/mlir-translate/mlir-translate.cpp b/mlir/tools/mlir-translate/mlir-translate.cpp index 914bd34..0d67286 100644 --- a/mlir/tools/mlir-translate/mlir-translate.cpp +++ b/mlir/tools/mlir-translate/mlir-translate.cpp @@ -88,7 +88,8 @@ int main(int argc, char **argv) { // Processes the memory buffer with a new MLIRContext. auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer, raw_ostream &os) { - MLIRContext context; + MLIRContext context(false); + registerAllDialects(&context); context.allowUnregisteredDialects(); context.printOpOnDiagnostic(!verifyDiagnostics); llvm::SourceMgr sourceMgr; diff --git a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp b/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp index 97c94a5..bae95e1 100644 --- a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp +++ b/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp @@ -17,9 +17,6 @@ using namespace mlir; using namespace mlir::quant; -// Load the quant dialect -static DialectRegistration<QuantizationDialect> QuantOpsRegistration; - namespace { // Test UniformQuantizedValueConverter converts all APFloat to a magic number 5. @@ -78,7 +75,8 @@ UniformQuantizedType getTestQuantizedType(Type storageType, MLIRContext *ctx) { } TEST(QuantizationUtilsTest, convertFloatAttrUniform) { - MLIRContext ctx; + MLIRContext ctx(/*loadAllDialects=*/false); + ctx.getOrLoadDialect<QuantizationDialect>(); IntegerType convertedType = IntegerType::get(8, &ctx); auto quantizedType = getTestQuantizedType(convertedType, &ctx); TestUniformQuantizedValueConverter converter(quantizedType); @@ -95,7 +93,8 @@ TEST(QuantizationUtilsTest, convertFloatAttrUniform) { } TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) { - MLIRContext ctx; + MLIRContext ctx(/*loadAllDialects=*/false); + ctx.getOrLoadDialect<QuantizationDialect>(); IntegerType convertedType = IntegerType::get(8, &ctx); auto quantizedType = getTestQuantizedType(convertedType, &ctx); TestUniformQuantizedValueConverter converter(quantizedType); @@ -119,7 +118,8 @@ TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) { } TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) { - MLIRContext ctx; + MLIRContext ctx(/*loadAllDialects=*/false); + ctx.getOrLoadDialect<QuantizationDialect>(); IntegerType convertedType = IntegerType::get(8, &ctx); auto quantizedType = getTestQuantizedType(convertedType, &ctx); TestUniformQuantizedValueConverter converter(quantizedType); @@ -143,7 +143,8 @@ TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) { } TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) { - MLIRContext ctx; + MLIRContext ctx(/*loadAllDialects=*/false); + ctx.getOrLoadDialect<QuantizationDialect>(); IntegerType convertedType = IntegerType::get(8, &ctx); auto quantizedType = getTestQuantizedType(convertedType, &ctx); TestUniformQuantizedValueConverter converter(quantizedType); diff --git a/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp index fe5632d..4aa2ffe 100644 --- a/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp +++ b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp @@ -38,7 +38,8 @@ using ::testing::StrEq; /// diagnostic checking utilities. class DeserializationTest : public ::testing::Test { protected: - DeserializationTest() { + DeserializationTest() : context(/*loadAllDialects=*/false) { + context.getOrLoadDialect<mlir::spirv::SPIRVDialect>(); // Register a diagnostic handler to capture the diagnostic so that we can // check it later. context.getDiagEngine().registerHandler([&](Diagnostic &diag) { diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp index 3d57e55..cb89cd6 100644 --- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp +++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp @@ -36,7 +36,10 @@ using namespace mlir; class SerializationTest : public ::testing::Test { protected: - SerializationTest() { createModuleOp(); } + SerializationTest() : context(/*loadAllDialects=*/false) { + context.getOrLoadDialect<mlir::spirv::SPIRVDialect>(); + createModuleOp(); + } void createModuleOp() { OpBuilder builder(&context); diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp index df449a0..78f7dd5 100644 --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -32,7 +32,7 @@ static void testSplat(Type eltType, const EltTy &splatElt) { namespace { TEST(DenseSplatTest, BoolSplat) { - MLIRContext context; + MLIRContext context(false); IntegerType boolTy = IntegerType::get(1, &context); RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy); @@ -57,7 +57,7 @@ TEST(DenseSplatTest, BoolSplat) { TEST(DenseSplatTest, LargeBoolSplat) { constexpr int64_t boolCount = 56; - MLIRContext context; + MLIRContext context(false); IntegerType boolTy = IntegerType::get(1, &context); RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy); @@ -80,7 +80,7 @@ TEST(DenseSplatTest, LargeBoolSplat) { } TEST(DenseSplatTest, BoolNonSplat) { - MLIRContext context; + MLIRContext context(false); IntegerType boolTy = IntegerType::get(1, &context); RankedTensorType shape = RankedTensorType::get({6}, boolTy); @@ -92,7 +92,7 @@ TEST(DenseSplatTest, BoolNonSplat) { TEST(DenseSplatTest, OddIntSplat) { // Test detecting a splat with an odd(non 8-bit) integer bitwidth. - MLIRContext context; + MLIRContext context(false); constexpr size_t intWidth = 19; IntegerType intTy = IntegerType::get(intWidth, &context); APInt value(intWidth, 10); @@ -101,7 +101,7 @@ TEST(DenseSplatTest, OddIntSplat) { } TEST(DenseSplatTest, Int32Splat) { - MLIRContext context; + MLIRContext context(false); IntegerType intTy = IntegerType::get(32, &context); int value = 64; @@ -109,7 +109,7 @@ TEST(DenseSplatTest, Int32Splat) { } TEST(DenseSplatTest, IntAttrSplat) { - MLIRContext context; + MLIRContext context(false); IntegerType intTy = IntegerType::get(85, &context); Attribute value = IntegerAttr::get(intTy, 109); @@ -117,7 +117,7 @@ TEST(DenseSplatTest, IntAttrSplat) { } TEST(DenseSplatTest, F32Splat) { - MLIRContext context; + MLIRContext context(false); FloatType floatTy = FloatType::getF32(&context); float value = 10.0; @@ -125,7 +125,7 @@ TEST(DenseSplatTest, F32Splat) { } TEST(DenseSplatTest, F64Splat) { - MLIRContext context; + MLIRContext context(false); FloatType floatTy = FloatType::getF64(&context); double value = 10.0; @@ -133,7 +133,7 @@ TEST(DenseSplatTest, F64Splat) { } TEST(DenseSplatTest, FloatAttrSplat) { - MLIRContext context; + MLIRContext context(false); FloatType floatTy = FloatType::getF32(&context); Attribute value = FloatAttr::get(floatTy, 10.0); @@ -141,7 +141,7 @@ TEST(DenseSplatTest, FloatAttrSplat) { } TEST(DenseSplatTest, BF16Splat) { - MLIRContext context; + MLIRContext context(false); FloatType floatTy = FloatType::getBF16(&context); Attribute value = FloatAttr::get(floatTy, 10.0); @@ -149,7 +149,7 @@ TEST(DenseSplatTest, BF16Splat) { } TEST(DenseSplatTest, StringSplat) { - MLIRContext context; + MLIRContext context(false); Type stringType = OpaqueType::get(Identifier::get("test", &context), "string", &context); StringRef value = "test-string"; @@ -157,7 +157,7 @@ TEST(DenseSplatTest, StringSplat) { } TEST(DenseSplatTest, StringAttrSplat) { - MLIRContext context; + MLIRContext context(false); Type stringType = OpaqueType::get(Identifier::get("test", &context), "string", &context); Attribute stringAttr = StringAttr::get("test-string", stringType); @@ -165,28 +165,28 @@ TEST(DenseSplatTest, StringAttrSplat) { } TEST(DenseComplexTest, ComplexFloatSplat) { - MLIRContext context; + MLIRContext context(false); ComplexType complexType = ComplexType::get(FloatType::getF32(&context)); std::complex<float> value(10.0, 15.0); testSplat(complexType, value); } TEST(DenseComplexTest, ComplexIntSplat) { - MLIRContext context; + MLIRContext context(false); ComplexType complexType = ComplexType::get(IntegerType::get(64, &context)); std::complex<int64_t> value(10, 15); testSplat(complexType, value); } TEST(DenseComplexTest, ComplexAPFloatSplat) { - MLIRContext context; + MLIRContext context(false); ComplexType complexType = ComplexType::get(FloatType::getF32(&context)); std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f)); testSplat(complexType, value); } TEST(DenseComplexTest, ComplexAPIntSplat) { - MLIRContext context; + MLIRContext context(false); ComplexType complexType = ComplexType::get(IntegerType::get(64, &context)); std::complex<APInt> value(APInt(64, 10), APInt(64, 15)); testSplat(complexType, value); diff --git a/mlir/unittests/IR/DialectTest.cpp b/mlir/unittests/IR/DialectTest.cpp index bc389ce..c43fb77 100644 --- a/mlir/unittests/IR/DialectTest.cpp +++ b/mlir/unittests/IR/DialectTest.cpp @@ -26,12 +26,12 @@ struct AnotherTestDialect : public Dialect { }; TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) { - MLIRContext context; + MLIRContext context(false); // Registering a dialect with the same namespace twice should result in a // failure. - context.getOrCreateDialect<TestDialect>(); - ASSERT_DEATH(context.getOrCreateDialect<AnotherTestDialect>(), ""); + context.getOrLoadDialect<TestDialect>(); + ASSERT_DEATH(context.getOrLoadDialect<AnotherTestDialect>(), ""); } } // end namespace diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp index 95ddccc..9669330 100644 --- a/mlir/unittests/IR/OperationSupportTest.cpp +++ b/mlir/unittests/IR/OperationSupportTest.cpp @@ -25,7 +25,7 @@ static Operation *createOp(MLIRContext *context, namespace { TEST(OperandStorageTest, NonResizable) { - MLIRContext context; + MLIRContext context(false); Builder builder(&context); Operation *useOp = @@ -49,7 +49,7 @@ TEST(OperandStorageTest, NonResizable) { } TEST(OperandStorageTest, Resizable) { - MLIRContext context; + MLIRContext context(false); Builder builder(&context); Operation *useOp = @@ -77,7 +77,7 @@ TEST(OperandStorageTest, Resizable) { } TEST(OperandStorageTest, RangeReplace) { - MLIRContext context; + MLIRContext context(false); Builder builder(&context); Operation *useOp = @@ -113,7 +113,7 @@ TEST(OperandStorageTest, RangeReplace) { } TEST(OperandStorageTest, MutableRange) { - MLIRContext context; + MLIRContext context(false); Builder builder(&context); Operation *useOp = diff --git a/mlir/unittests/Pass/AnalysisManagerTest.cpp b/mlir/unittests/Pass/AnalysisManagerTest.cpp index a99df39..958cf43 100644 --- a/mlir/unittests/Pass/AnalysisManagerTest.cpp +++ b/mlir/unittests/Pass/AnalysisManagerTest.cpp @@ -24,7 +24,7 @@ struct OtherAnalysis { }; TEST(AnalysisManagerTest, FineGrainModuleAnalysisPreservation) { - MLIRContext context; + MLIRContext context(false); // Test fine grain invalidation of the module analysis manager. OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context))); @@ -45,7 +45,7 @@ TEST(AnalysisManagerTest, FineGrainModuleAnalysisPreservation) { } TEST(AnalysisManagerTest, FineGrainFunctionAnalysisPreservation) { - MLIRContext context; + MLIRContext context(false); Builder builder(&context); // Create a function and a module. @@ -74,7 +74,7 @@ TEST(AnalysisManagerTest, FineGrainFunctionAnalysisPreservation) { } TEST(AnalysisManagerTest, FineGrainChildFunctionAnalysisPreservation) { - MLIRContext context; + MLIRContext context(false); Builder builder(&context); // Create a function and a module. @@ -117,7 +117,7 @@ struct CustomInvalidatingAnalysis { }; TEST(AnalysisManagerTest, CustomInvalidation) { - MLIRContext context; + MLIRContext context(false); Builder builder(&context); // Create a function and a module. diff --git a/mlir/unittests/SDBM/SDBMTest.cpp b/mlir/unittests/SDBM/SDBMTest.cpp index 61d6706..bbe87e3 100644 --- a/mlir/unittests/SDBM/SDBMTest.cpp +++ b/mlir/unittests/SDBM/SDBMTest.cpp @@ -17,18 +17,17 @@ using namespace mlir; -/// Load the SDBM dialect. -static DialectRegistration<SDBMDialect> SDBMRegistration; static MLIRContext *ctx() { - static thread_local MLIRContext context; + static thread_local MLIRContext context(false); + context.getOrLoadDialect<SDBMDialect>(); return &context; } static SDBMDialect *dialect() { static thread_local SDBMDialect *d = nullptr; if (!d) { - d = ctx()->getRegisteredDialect<SDBMDialect>(); + d = ctx()->getOrLoadDialect<SDBMDialect>(); } return d; } diff --git a/mlir/unittests/TableGen/OpBuildGen.cpp b/mlir/unittests/TableGen/OpBuildGen.cpp index 3e3256e..46a37da 100644 --- a/mlir/unittests/TableGen/OpBuildGen.cpp +++ b/mlir/unittests/TableGen/OpBuildGen.cpp @@ -25,11 +25,16 @@ namespace mlir { // Test Fixture //===----------------------------------------------------------------------===// +static MLIRContext &getContext() { + static MLIRContext ctx(false); + ctx.getOrLoadDialect<TestDialect>(); + return ctx; +} /// Test fixture for providing basic utilities for testing. class OpBuildGenTest : public ::testing::Test { protected: OpBuildGenTest() - : ctx{}, builder(&ctx), loc(builder.getUnknownLoc()), + : ctx(getContext()), builder(&ctx), loc(builder.getUnknownLoc()), i32Ty(builder.getI32Type()), f32Ty(builder.getF32Type()), cstI32(builder.create<TableGenConstant>(loc, i32Ty)), cstF32(builder.create<TableGenConstant>(loc, f32Ty)), @@ -86,7 +91,7 @@ protected: } protected: - MLIRContext ctx; + MLIRContext &ctx; OpBuilder builder; Location loc; Type i32Ty; diff --git a/mlir/unittests/TableGen/StructsGenTest.cpp b/mlir/unittests/TableGen/StructsGenTest.cpp index c58fedb..14b0abc 100644 --- a/mlir/unittests/TableGen/StructsGenTest.cpp +++ b/mlir/unittests/TableGen/StructsGenTest.cpp @@ -42,7 +42,7 @@ static test::TestStruct getTestStruct(mlir::MLIRContext *context) { /// Validates that test::TestStruct::classof correctly identifies a valid /// test::TestStruct. TEST(StructsGenTest, ClassofTrue) { - mlir::MLIRContext context; + mlir::MLIRContext context(false); auto structAttr = getTestStruct(&context); ASSERT_TRUE(test::TestStruct::classof(structAttr)); } |