aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--flang/unittests/Lower/OpenMPLoweringTest.cpp3
-rw-r--r--mlir/examples/standalone/standalone-opt/standalone-opt.cpp2
-rw-r--r--mlir/examples/toy/Ch2/toyc.cpp7
-rw-r--r--mlir/examples/toy/Ch3/toyc.cpp6
-rw-r--r--mlir/examples/toy/Ch4/toyc.cpp6
-rw-r--r--mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp4
-rw-r--r--mlir/examples/toy/Ch5/toyc.cpp6
-rw-r--r--mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp4
-rw-r--r--mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp4
-rw-r--r--mlir/examples/toy/Ch6/toyc.cpp6
-rw-r--r--mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp4
-rw-r--r--mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp4
-rw-r--r--mlir/examples/toy/Ch7/toyc.cpp6
-rw-r--r--mlir/include/mlir-c/IR.h6
-rw-r--r--mlir/include/mlir/Conversion/Passes.td26
-rw-r--r--mlir/include/mlir/Dialect/Affine/Passes.td1
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h1
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td5
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h1
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td1
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h1
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td1
-rw-r--r--mlir/include/mlir/Dialect/Linalg/Passes.td8
-rw-r--r--mlir/include/mlir/Dialect/SCF/Passes.td1
-rw-r--r--mlir/include/mlir/IR/Dialect.h83
-rw-r--r--mlir/include/mlir/IR/FunctionSupport.h4
-rw-r--r--mlir/include/mlir/IR/MLIRContext.h54
-rw-r--r--mlir/include/mlir/IR/OpBase.td5
-rw-r--r--mlir/include/mlir/InitAllDialects.h43
-rw-r--r--mlir/include/mlir/Pass/Pass.h8
-rw-r--r--mlir/include/mlir/Pass/PassBase.td3
-rw-r--r--mlir/include/mlir/Pass/PassManager.h14
-rw-r--r--mlir/include/mlir/Support/MlirOptMain.h7
-rw-r--r--mlir/include/mlir/TableGen/Dialect.h8
-rw-r--r--mlir/include/mlir/TableGen/Pass.h4
-rw-r--r--mlir/include/mlir/Transforms/Passes.td2
-rw-r--r--mlir/lib/CAPI/IR/IR.cpp9
-rw-r--r--mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp1
-rw-r--r--mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp1
-rw-r--r--mlir/lib/Conversion/PassDetail.h32
-rw-r--r--mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp2
-rw-r--r--mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp1
-rw-r--r--mlir/lib/Dialect/Affine/Transforms/PassDetail.h10
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp1
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/PassDetail.h9
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/PassDetail.h5
-rw-r--r--mlir/lib/Dialect/SDBM/SDBMExpr.cpp2
-rw-r--r--mlir/lib/ExecutionEngine/JitRunner.cpp4
-rw-r--r--mlir/lib/IR/Dialect.cpp32
-rw-r--r--mlir/lib/IR/MLIRContext.cpp87
-rw-r--r--mlir/lib/IR/Operation.cpp2
-rw-r--r--mlir/lib/IR/Verifier.cpp4
-rw-r--r--mlir/lib/Parser/AttributeParser.cpp10
-rw-r--r--mlir/lib/Parser/DialectSymbolParser.cpp7
-rw-r--r--mlir/lib/Parser/Parser.cpp41
-rw-r--r--mlir/lib/Pass/Pass.cpp27
-rw-r--r--mlir/lib/Pass/PassDetail.h4
-rw-r--r--mlir/lib/Support/MlirOptMain.cpp14
-rw-r--r--mlir/lib/TableGen/Dialect.cpp9
-rw-r--r--mlir/lib/TableGen/Pass.cpp5
-rw-r--r--mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp1
-rw-r--r--mlir/lib/Target/LLVMIR/ModuleTranslation.cpp5
-rw-r--r--mlir/lib/Transforms/PassDetail.h7
-rw-r--r--mlir/test/CAPI/ir.c1
-rw-r--r--mlir/test/EDSC/builder-api-test.cpp14
-rw-r--r--mlir/test/SDBM/sdbm-api-test.cpp9
-rw-r--r--mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp4
-rw-r--r--mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp2
-rw-r--r--mlir/test/lib/Dialect/Test/TestPatterns.cpp4
-rw-r--r--mlir/test/lib/Transforms/TestAllReduceLowering.cpp4
-rw-r--r--mlir/test/lib/Transforms/TestBufferPlacement.cpp4
-rw-r--r--mlir/test/lib/Transforms/TestGpuMemoryPromotion.cpp8
-rw-r--r--mlir/test/lib/Transforms/TestLinalgHoisting.cpp4
-rw-r--r--mlir/test/lib/Transforms/TestLinalgTransforms.cpp9
-rw-r--r--mlir/test/lib/Transforms/TestVectorTransforms.cpp10
-rw-r--r--mlir/test/mlir-opt/commandline.mlir2
-rw-r--r--mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp2
-rw-r--r--mlir/tools/mlir-opt/mlir-opt.cpp9
-rw-r--r--mlir/tools/mlir-tblgen/DialectGen.cpp20
-rw-r--r--mlir/tools/mlir-tblgen/PassGen.cpp21
-rw-r--r--mlir/tools/mlir-translate/mlir-translate.cpp3
-rw-r--r--mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp15
-rw-r--r--mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp3
-rw-r--r--mlir/unittests/Dialect/SPIRV/SerializationTest.cpp5
-rw-r--r--mlir/unittests/IR/AttributeTest.cpp32
-rw-r--r--mlir/unittests/IR/DialectTest.cpp6
-rw-r--r--mlir/unittests/IR/OperationSupportTest.cpp8
-rw-r--r--mlir/unittests/Pass/AnalysisManagerTest.cpp8
-rw-r--r--mlir/unittests/SDBM/SDBMTest.cpp7
-rw-r--r--mlir/unittests/TableGen/OpBuildGen.cpp9
-rw-r--r--mlir/unittests/TableGen/StructsGenTest.cpp2
91 files changed, 692 insertions, 214 deletions
diff --git a/flang/unittests/Lower/OpenMPLoweringTest.cpp b/flang/unittests/Lower/OpenMPLoweringTest.cpp
index ad6fe73..4c23845 100644
--- a/flang/unittests/Lower/OpenMPLoweringTest.cpp
+++ b/flang/unittests/Lower/OpenMPLoweringTest.cpp
@@ -15,8 +15,7 @@
class OpenMPLoweringTest : public testing::Test {
protected:
void SetUp() override {
- mlir::registerDialect<mlir::omp::OpenMPDialect>();
- mlir::registerAllDialects(&ctx);
+ ctx.getOrLoadDialect<mlir::omp::OpenMPDialect>();
mlirOpBuilder.reset(new mlir::OpBuilder(&ctx));
}
diff --git a/mlir/examples/standalone/standalone-opt/standalone-opt.cpp b/mlir/examples/standalone/standalone-opt/standalone-opt.cpp
index 5c99058..eb624b3 100644
--- a/mlir/examples/standalone/standalone-opt/standalone-opt.cpp
+++ b/mlir/examples/standalone/standalone-opt/standalone-opt.cpp
@@ -76,7 +76,7 @@ int main(int argc, char **argv) {
if (showDialects) {
mlir::MLIRContext context;
llvm::outs() << "Registered Dialects:\n";
- for (mlir::Dialect *dialect : context.getRegisteredDialects()) {
+ for (mlir::Dialect *dialect : context.getLoadedDialects()) {
llvm::outs() << dialect->getNamespace() << "\n";
}
return 0;
diff --git a/mlir/examples/toy/Ch2/toyc.cpp b/mlir/examples/toy/Ch2/toyc.cpp
index d0880ce..99232d8 100644
--- a/mlir/examples/toy/Ch2/toyc.cpp
+++ b/mlir/examples/toy/Ch2/toyc.cpp
@@ -68,10 +68,9 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
}
int dumpMLIR() {
- // Register our Dialect with MLIR.
- mlir::registerDialect<mlir::toy::ToyDialect>();
-
- mlir::MLIRContext context;
+ mlir::MLIRContext context(/*loadAllDialects=*/false);
+ // Load our Dialect in this MLIR Context.
+ context.getOrLoadDialect<mlir::toy::ToyDialect>();
// Handle '.toy' input to the compiler.
if (inputType != InputType::MLIR &&
diff --git a/mlir/examples/toy/Ch3/toyc.cpp b/mlir/examples/toy/Ch3/toyc.cpp
index f9d5631..d0430ce 100644
--- a/mlir/examples/toy/Ch3/toyc.cpp
+++ b/mlir/examples/toy/Ch3/toyc.cpp
@@ -102,10 +102,10 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
}
int dumpMLIR() {
- // Register our Dialect with MLIR.
- mlir::registerDialect<mlir::toy::ToyDialect>();
+ mlir::MLIRContext context(/*loadAllDialects=*/false);
+ // Load our Dialect in this MLIR Context.
+ context.getOrLoadDialect<mlir::toy::ToyDialect>();
- mlir::MLIRContext context;
mlir::OwningModuleRef module;
llvm::SourceMgr sourceMgr;
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
diff --git a/mlir/examples/toy/Ch4/toyc.cpp b/mlir/examples/toy/Ch4/toyc.cpp
index e11f35c..9f95887 100644
--- a/mlir/examples/toy/Ch4/toyc.cpp
+++ b/mlir/examples/toy/Ch4/toyc.cpp
@@ -103,10 +103,10 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
}
int dumpMLIR() {
- // Register our Dialect with MLIR.
- mlir::registerDialect<mlir::toy::ToyDialect>();
+ mlir::MLIRContext context(/*loadAllDialects=*/false);
+ // Load our Dialect in this MLIR Context.
+ context.getOrLoadDialect<mlir::toy::ToyDialect>();
- mlir::MLIRContext context;
mlir::OwningModuleRef module;
llvm::SourceMgr sourceMgr;
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
index 3097681..1077fc9 100644
--- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
@@ -256,6 +256,10 @@ struct TransposeOpLowering : public ConversionPattern {
namespace {
struct ToyToAffineLoweringPass
: public PassWrapper<ToyToAffineLoweringPass, FunctionPass> {
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<AffineDialect>();
+ registry.insert<StandardOpsDialect>();
+ }
void runOnFunction() final;
};
} // end anonymous namespace.
diff --git a/mlir/examples/toy/Ch5/toyc.cpp b/mlir/examples/toy/Ch5/toyc.cpp
index ed04969..16faac0 100644
--- a/mlir/examples/toy/Ch5/toyc.cpp
+++ b/mlir/examples/toy/Ch5/toyc.cpp
@@ -106,10 +106,10 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
}
int dumpMLIR() {
- // Register our Dialect with MLIR.
- mlir::registerDialect<mlir::toy::ToyDialect>();
+ mlir::MLIRContext context(/*loadAllDialects=*/false);
+ // Load our Dialect in this MLIR Context.
+ context.getOrLoadDialect<mlir::toy::ToyDialect>();
- mlir::MLIRContext context;
mlir::OwningModuleRef module;
llvm::SourceMgr sourceMgr;
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
index cac3415..9ff9eb4 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
@@ -255,6 +255,10 @@ struct TransposeOpLowering : public ConversionPattern {
namespace {
struct ToyToAffineLoweringPass
: public PassWrapper<ToyToAffineLoweringPass, FunctionPass> {
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<AffineDialect>();
+ registry.insert<StandardOpsDialect>();
+ }
void runOnFunction() final;
};
} // end anonymous namespace.
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
index 74b32dc..8020fb3 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
@@ -159,6 +159,10 @@ private:
namespace {
struct ToyToLLVMLoweringPass
: public PassWrapper<ToyToLLVMLoweringPass, OperationPass<ModuleOp>> {
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<LLVM::LLVMDialect>();
+ registry.insert<scf::SCFDialect>();
+ }
void runOnOperation() final;
};
} // end anonymous namespace
diff --git a/mlir/examples/toy/Ch6/toyc.cpp b/mlir/examples/toy/Ch6/toyc.cpp
index bdcdf1a..9504a38 100644
--- a/mlir/examples/toy/Ch6/toyc.cpp
+++ b/mlir/examples/toy/Ch6/toyc.cpp
@@ -255,10 +255,10 @@ int main(int argc, char **argv) {
// If we aren't dumping the AST, then we are compiling with/to MLIR.
- // Register our Dialect with MLIR.
- mlir::registerDialect<mlir::toy::ToyDialect>();
+ mlir::MLIRContext context(/*loadAllDialects=*/false);
+ // Load our Dialect in this MLIR Context.
+ context.getOrLoadDialect<mlir::toy::ToyDialect>();
- mlir::MLIRContext context;
mlir::OwningModuleRef module;
if (int error = loadAndProcessMLIR(context, module))
return error;
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
index 3097681..1077fc9 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
@@ -256,6 +256,10 @@ struct TransposeOpLowering : public ConversionPattern {
namespace {
struct ToyToAffineLoweringPass
: public PassWrapper<ToyToAffineLoweringPass, FunctionPass> {
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<AffineDialect>();
+ registry.insert<StandardOpsDialect>();
+ }
void runOnFunction() final;
};
} // end anonymous namespace.
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
index 74b32dc..8020fb3 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
@@ -159,6 +159,10 @@ private:
namespace {
struct ToyToLLVMLoweringPass
: public PassWrapper<ToyToLLVMLoweringPass, OperationPass<ModuleOp>> {
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<LLVM::LLVMDialect>();
+ registry.insert<scf::SCFDialect>();
+ }
void runOnOperation() final;
};
} // end anonymous namespace
diff --git a/mlir/examples/toy/Ch7/toyc.cpp b/mlir/examples/toy/Ch7/toyc.cpp
index c1cc207..cb3b455 100644
--- a/mlir/examples/toy/Ch7/toyc.cpp
+++ b/mlir/examples/toy/Ch7/toyc.cpp
@@ -256,10 +256,10 @@ int main(int argc, char **argv) {
// If we aren't dumping the AST, then we are compiling with/to MLIR.
- // Register our Dialect with MLIR.
- mlir::registerDialect<mlir::toy::ToyDialect>();
+ mlir::MLIRContext context(/*loadAllDialects=*/false);
+ // Load our Dialect in this MLIR Context.
+ context.getOrLoadDialect<mlir::toy::ToyDialect>();
- mlir::MLIRContext context;
mlir::OwningModuleRef module;
if (int error = loadAndProcessMLIR(context, module))
return error;
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 6b5be2d..f9ec4d1 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -90,6 +90,12 @@ MlirContext mlirContextCreate();
/** Takes an MLIR context owned by the caller and destroys it. */
void mlirContextDestroy(MlirContext context);
+/** Load all the globally registered dialects in the provided context.
+ * TODO: remove the concept of globally registered dialect by exposing the
+ * DialectRegistry.
+ */
+void mlirContextLoadAllDialects(MlirContext context);
+
/*============================================================================*/
/* Location API. */
/*============================================================================*/
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 4d4fe06..0c40bb3 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -66,6 +66,11 @@ def ConvertAffineToStandard : Pass<"lower-affine"> {
`affine.apply`.
}];
let constructor = "mlir::createLowerAffinePass()";
+ let dependentDialects = [
+ "scf::SCFDialect",
+ "StandardOpsDialect",
+ "vector::VectorDialect"
+ ];
}
//===----------------------------------------------------------------------===//
@@ -76,6 +81,7 @@ def ConvertAVX512ToLLVM : Pass<"convert-avx512-to-llvm", "ModuleOp"> {
let summary = "Convert the operations from the avx512 dialect into the LLVM "
"dialect";
let constructor = "mlir::createConvertAVX512ToLLVMPass()";
+ let dependentDialects = ["LLVM::LLVMDialect", "LLVM::LLVMAVX512Dialect"];
}
//===----------------------------------------------------------------------===//
@@ -98,6 +104,7 @@ def GpuToLLVMConversionPass : Pass<"gpu-to-llvm", "ModuleOp"> {
def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> {
let summary = "Generate NVVM operations for gpu operations";
let constructor = "mlir::createLowerGpuOpsToNVVMOpsPass()";
+ let dependentDialects = ["NVVM::NVVMDialect"];
let options = [
Option<"indexBitwidth", "index-bitwidth", "unsigned",
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
@@ -112,6 +119,7 @@ def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> {
def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
let summary = "Generate ROCDL operations for gpu operations";
let constructor = "mlir::createLowerGpuOpsToROCDLOpsPass()";
+ let dependentDialects = ["ROCDL::ROCDLDialect"];
let options = [
Option<"indexBitwidth", "index-bitwidth", "unsigned",
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
@@ -126,6 +134,7 @@ def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
def ConvertGPUToSPIRV : Pass<"convert-gpu-to-spirv", "ModuleOp"> {
let summary = "Convert GPU dialect to SPIR-V dialect";
let constructor = "mlir::createConvertGPUToSPIRVPass()";
+ let dependentDialects = ["spirv::SPIRVDialect"];
}
//===----------------------------------------------------------------------===//
@@ -136,6 +145,7 @@ def ConvertGpuLaunchFuncToVulkanLaunchFunc
: Pass<"convert-gpu-launch-to-vulkan-launch", "ModuleOp"> {
let summary = "Convert gpu.launch_func to vulkanLaunch external call";
let constructor = "mlir::createConvertGpuLaunchFuncToVulkanLaunchFuncPass()";
+ let dependentDialects = ["spirv::SPIRVDialect"];
}
def ConvertVulkanLaunchFuncToVulkanCalls
@@ -143,6 +153,7 @@ def ConvertVulkanLaunchFuncToVulkanCalls
let summary = "Convert vulkanLaunch external call to Vulkan runtime external "
"calls";
let constructor = "mlir::createConvertVulkanLaunchFuncToVulkanCallsPass()";
+ let dependentDialects = ["LLVM::LLVMDialect"];
}
//===----------------------------------------------------------------------===//
@@ -153,6 +164,7 @@ def ConvertLinalgToLLVM : Pass<"convert-linalg-to-llvm", "ModuleOp"> {
let summary = "Convert the operations from the linalg dialect into the LLVM "
"dialect";
let constructor = "mlir::createConvertLinalgToLLVMPass()";
+ let dependentDialects = ["scf::SCFDialect", "LLVM::LLVMDialect"];
}
//===----------------------------------------------------------------------===//
@@ -163,6 +175,7 @@ def ConvertLinalgToStandard : Pass<"convert-linalg-to-std", "ModuleOp"> {
let summary = "Convert the operations from the linalg dialect into the "
"Standard dialect";
let constructor = "mlir::createConvertLinalgToStandardPass()";
+ let dependentDialects = ["StandardOpsDialect"];
}
//===----------------------------------------------------------------------===//
@@ -172,6 +185,7 @@ def ConvertLinalgToStandard : Pass<"convert-linalg-to-std", "ModuleOp"> {
def ConvertLinalgToSPIRV : Pass<"convert-linalg-to-spirv", "ModuleOp"> {
let summary = "Convert Linalg ops to SPIR-V ops";
let constructor = "mlir::createLinalgToSPIRVPass()";
+ let dependentDialects = ["spirv::SPIRVDialect"];
}
//===----------------------------------------------------------------------===//
@@ -182,6 +196,7 @@ def SCFToStandard : Pass<"convert-scf-to-std"> {
let summary = "Convert SCF dialect to Standard dialect, replacing structured"
" control flow with a CFG";
let constructor = "mlir::createLowerToCFGPass()";
+ let dependentDialects = ["StandardOpsDialect"];
}
//===----------------------------------------------------------------------===//
@@ -191,6 +206,7 @@ def SCFToStandard : Pass<"convert-scf-to-std"> {
def ConvertAffineForToGPU : FunctionPass<"convert-affine-for-to-gpu"> {
let summary = "Convert top-level AffineFor Ops to GPU kernels";
let constructor = "mlir::createAffineForToGPUPass()";
+ let dependentDialects = ["gpu::GPUDialect"];
let options = [
Option<"numBlockDims", "gpu-block-dims", "unsigned", /*default=*/"1u",
"Number of GPU block dimensions for mapping">,
@@ -202,6 +218,7 @@ def ConvertAffineForToGPU : FunctionPass<"convert-affine-for-to-gpu"> {
def ConvertParallelLoopToGpu : Pass<"convert-parallel-loops-to-gpu"> {
let summary = "Convert mapped scf.parallel ops to gpu launch operations";
let constructor = "mlir::createParallelLoopToGpuPass()";
+ let dependentDialects = ["AffineDialect", "gpu::GPUDialect"];
}
//===----------------------------------------------------------------------===//
@@ -212,6 +229,7 @@ def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> {
let summary = "Convert operations from the shape dialect into the standard "
"dialect";
let constructor = "mlir::createConvertShapeToStandardPass()";
+ let dependentDialects = ["StandardOpsDialect"];
}
//===----------------------------------------------------------------------===//
@@ -221,6 +239,7 @@ def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> {
def ConvertShapeToSCF : FunctionPass<"convert-shape-to-scf"> {
let summary = "Convert operations from the shape dialect to the SCF dialect";
let constructor = "mlir::createConvertShapeToSCFPass()";
+ let dependentDialects = ["scf::SCFDialect"];
}
//===----------------------------------------------------------------------===//
@@ -230,6 +249,7 @@ def ConvertShapeToSCF : FunctionPass<"convert-shape-to-scf"> {
def ConvertSPIRVToLLVM : Pass<"convert-spirv-to-llvm", "ModuleOp"> {
let summary = "Convert SPIR-V dialect to LLVM dialect";
let constructor = "mlir::createConvertSPIRVToLLVMPass()";
+ let dependentDialects = ["LLVM::LLVMDialect"];
}
//===----------------------------------------------------------------------===//
@@ -264,6 +284,7 @@ def ConvertStandardToLLVM : Pass<"convert-std-to-llvm", "ModuleOp"> {
LLVM IR types.
}];
let constructor = "mlir::createLowerToLLVMPass()";
+ let dependentDialects = ["LLVM::LLVMDialect"];
let options = [
Option<"useAlignedAlloc", "use-aligned-alloc", "bool", /*default=*/"false",
"Use aligned_alloc in place of malloc for heap allocations">,
@@ -287,11 +308,13 @@ def ConvertStandardToLLVM : Pass<"convert-std-to-llvm", "ModuleOp"> {
def LegalizeStandardForSPIRV : Pass<"legalize-std-for-spirv"> {
let summary = "Legalize standard ops for SPIR-V lowering";
let constructor = "mlir::createLegalizeStdOpsForSPIRVLoweringPass()";
+ let dependentDialects = ["spirv::SPIRVDialect"];
}
def ConvertStandardToSPIRV : Pass<"convert-std-to-spirv", "ModuleOp"> {
let summary = "Convert Standard Ops to SPIR-V dialect";
let constructor = "mlir::createConvertStandardToSPIRVPass()";
+ let dependentDialects = ["spirv::SPIRVDialect"];
}
//===----------------------------------------------------------------------===//
@@ -302,6 +325,7 @@ def ConvertVectorToSCF : FunctionPass<"convert-vector-to-scf"> {
let summary = "Lower the operations from the vector dialect into the SCF "
"dialect";
let constructor = "mlir::createConvertVectorToSCFPass()";
+ let dependentDialects = ["AffineDialect", "scf::SCFDialect"];
let options = [
Option<"fullUnroll", "full-unroll", "bool", /*default=*/"false",
"Perform full unrolling when converting vector transfers to SCF">,
@@ -316,6 +340,7 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
let summary = "Lower the operations from the vector dialect into the LLVM "
"dialect";
let constructor = "mlir::createConvertVectorToLLVMPass()";
+ let dependentDialects = ["LLVM::LLVMDialect"];
let options = [
Option<"reassociateFPReductions", "reassociate-fp-reductions",
"bool", /*default=*/"false",
@@ -331,6 +356,7 @@ def ConvertVectorToROCDL : Pass<"convert-vector-to-rocdl", "ModuleOp"> {
let summary = "Lower the operations from the vector dialect into the ROCDL "
"dialect";
let constructor = "mlir::createConvertVectorToROCDLPass()";
+ let dependentDialects = ["ROCDL::ROCDLDialect"];
}
#endif // MLIR_CONVERSION_PASSES
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td
index 8106400..f43fabd 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Passes.td
@@ -94,6 +94,7 @@ def AffineLoopUnrollAndJam : FunctionPass<"affine-loop-unroll-jam"> {
def AffineVectorize : FunctionPass<"affine-super-vectorize"> {
let summary = "Vectorize to a target independent n-D vector abstraction";
let constructor = "mlir::createSuperVectorizePass()";
+ let dependentDialects = ["vector::VectorDialect"];
let options = [
ListOption<"vectorSizes", "virtual-vector-size", "int64_t",
"Specify an n-D virtual vector size for vectorization",
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index 04700f0..2f465f0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -15,6 +15,7 @@
#define MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/OpDefinition.h"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index d21f5bc..2617a2d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -19,6 +19,11 @@ include "mlir/IR/OpBase.td"
def LLVM_Dialect : Dialect {
let name = "llvm";
let cppNamespace = "LLVM";
+
+ /// FIXME: at the moment this is a dependency of the translation to LLVM IR,
+ /// not really one of this dialect per-se.
+ let dependentDialects = [ "omp::OpenMPDialect" ];
+
let hasRegionArgAttrVerify = 1;
let extraClassDeclaration = [{
~LLVMDialect();
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index 86d437c..9cc5314 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -14,6 +14,7 @@
#ifndef MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_
#define MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 5f022e32..0e5bc16 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -23,6 +23,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
def NVVM_Dialect : Dialect {
let name = "nvvm";
let cppNamespace = "NVVM";
+ let dependentDialects = [ "LLVM::LLVMDialect" ];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h b/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h
index bf761c3..eb40373 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h
@@ -22,6 +22,7 @@
#ifndef MLIR_DIALECT_LLVMIR_ROCDLDIALECT_H_
#define MLIR_DIALECT_LLVMIR_ROCDLDIALECT_H_
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 0cd1169..f9aca5a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -23,6 +23,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
def ROCDL_Dialect : Dialect {
let name = "rocdl";
let cppNamespace = "ROCDL";
+ let dependentDialects = [ "LLVM::LLVMDialect" ];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 11f12ad..dcf4b5e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -30,17 +30,20 @@ def LinalgFusion : FunctionPass<"linalg-fusion"> {
def LinalgFusionOfTensorOps : Pass<"linalg-fusion-for-tensor-ops"> {
let summary = "Fuse operations on RankedTensorType in linalg dialect";
let constructor = "mlir::createLinalgFusionOfTensorOpsPass()";
+ let dependentDialects = ["AffineDialect"];
}
def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
let summary = "Lower the operations from the linalg dialect into affine "
"loops";
let constructor = "mlir::createConvertLinalgToAffineLoopsPass()";
+ let dependentDialects = ["AffineDialect"];
}
def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> {
let summary = "Lower the operations from the linalg dialect into loops";
let constructor = "mlir::createConvertLinalgToLoopsPass()";
+ let dependentDialects = ["scf::SCFDialect", "AffineDialect"];
}
def LinalgOnTensorsToBuffers : Pass<"convert-linalg-on-tensors-to-buffers", "ModuleOp"> {
@@ -54,6 +57,7 @@ def LinalgLowerToParallelLoops
let summary = "Lower the operations from the linalg dialect into parallel "
"loops";
let constructor = "mlir::createConvertLinalgToParallelLoopsPass()";
+ let dependentDialects = ["AffineDialect", "scf::SCFDialect"];
}
def LinalgPromotion : FunctionPass<"linalg-promote-subviews"> {
@@ -70,6 +74,9 @@ def LinalgPromotion : FunctionPass<"linalg-promote-subviews"> {
def LinalgTiling : FunctionPass<"linalg-tile"> {
let summary = "Tile operations in the linalg dialect";
let constructor = "mlir::createLinalgTilingPass()";
+ let dependentDialects = [
+ "AffineDialect", "scf::SCFDialect"
+ ];
let options = [
ListOption<"tileSizes", "linalg-tile-sizes", "int64_t",
"Test generation of dynamic promoted buffers",
@@ -86,6 +93,7 @@ def LinalgTilingToParallelLoops
"Test generation of dynamic promoted buffers",
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
];
+ let dependentDialects = ["AffineDialect", "scf::SCFDialect"];
}
#endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td
index 483d0ba..6f3cf0e 100644
--- a/mlir/include/mlir/Dialect/SCF/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Passes.td
@@ -36,6 +36,7 @@ def SCFParallelLoopTiling : FunctionPass<"parallel-loop-tiling"> {
"Factors to tile parallel loops by",
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
];
+ let dependentDialects = ["AffineDialect"];
}
#endif // MLIR_DIALECT_SCF_PASSES
diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index 4f9e4cb..d00d86db 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -16,6 +16,8 @@
#include "mlir/IR/OperationSupport.h"
#include "mlir/Support/TypeID.h"
+#include <map>
+
namespace mlir {
class DialectAsmParser;
class DialectAsmPrinter;
@@ -23,7 +25,7 @@ class DialectInterface;
class OpBuilder;
class Type;
-using DialectAllocatorFunction = std::function<void(MLIRContext *)>;
+using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
/// Dialects are groups of MLIR operations and behavior associated with the
/// entire group. For example, hooks into other systems for constant folding,
@@ -212,30 +214,81 @@ private:
/// A collection of registered dialect interfaces.
DenseMap<TypeID, std::unique_ptr<DialectInterface>> registeredInterfaces;
- /// Registers a specific dialect creation function with the global registry.
- /// Used through the registerDialect template.
- /// Registrations are deduplicated by dialect TypeID and only the first
- /// registration will be used.
- static void
- registerDialectAllocator(TypeID typeID,
- const DialectAllocatorFunction &function);
- template <typename ConcreteDialect>
friend void registerDialect();
friend class MLIRContext;
};
-/// Registers all dialects and hooks from the global registries with the
-/// specified MLIRContext.
+/// The DialectRegistry maps a dialect namespace to a constructor for the
+/// matching dialect.
+/// This allows for decoupling the list of dialects "available" from the
+/// dialects loaded in the Context. The parser in particular will lazily load
+/// dialects in in the Context as operations are encountered.
+class DialectRegistry {
+ using MapTy =
+ std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;
+
+public:
+ template <typename ConcreteDialect> void insert() {
+ insert(TypeID::get<ConcreteDialect>(),
+ ConcreteDialect::getDialectNamespace(),
+ static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
+ // Just allocate the dialect, the context
+ // takes ownership of it.
+ return ctx->getOrLoadDialect<ConcreteDialect>();
+ })));
+ }
+
+ /// Add a new dialect constructor to the registry.
+ void insert(TypeID typeID, StringRef name, DialectAllocatorFunction ctor);
+
+ /// Load a dialect for this namespace in the provided context.
+ Dialect *loadByName(StringRef name, MLIRContext *context);
+
+ // Register all dialects available in the current registry with the registry
+ // in the provided context.
+ void appendTo(DialectRegistry &destination) {
+ for (const auto &name_and_registration_it : registry)
+ destination.insert(name_and_registration_it.second.first,
+ name_and_registration_it.first,
+ name_and_registration_it.second.second);
+ }
+ // Load all dialects available in the registry in the provided context.
+ void loadAll(MLIRContext *context) {
+ for (const auto &name_and_registration_it : registry)
+ name_and_registration_it.second.second(context);
+ }
+
+ MapTy::const_iterator begin() const { return registry.begin(); }
+ MapTy::const_iterator end() const { return registry.end(); }
+
+private:
+ MapTy registry;
+};
+
+/// Deprecated: this provides a global registry for convenience, while we're
+/// transitionning the registration mechanism to a stateless approach.
+DialectRegistry &getGlobalDialectRegistry();
+
+/// Registers all dialects from the global registries with the
+/// specified MLIRContext. This won't load the dialects in the context,
+/// but only make them available for lazy loading by name.
/// Note: This method is not thread-safe.
-void registerAllDialects(MLIRContext *context);
+inline void registerAllDialects(MLIRContext *context) {
+ getGlobalDialectRegistry().appendTo(context->getDialectRegistry());
+}
+
+/// Register and return the dialect with the given namespace in the provided
+/// context. Returns nullptr is there is no constructor registered for this
+/// dialect.
+inline Dialect *registerDialect(StringRef name, MLIRContext *context) {
+ return getGlobalDialectRegistry().loadByName(name, context);
+}
/// Utility to register a dialect. Client can register their dialect with the
/// global registry by calling registerDialect<MyDialect>();
/// Note: This method is not thread-safe.
template <typename ConcreteDialect> void registerDialect() {
- Dialect::registerDialectAllocator(
- TypeID::get<ConcreteDialect>(),
- [](MLIRContext *ctx) { ctx->getOrCreateDialect<ConcreteDialect>(); });
+ getGlobalDialectRegistry().insert<ConcreteDialect>();
}
/// DialectRegistration provides a global initializer that registers a Dialect
diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h
index 7e281f3..3d467cd 100644
--- a/mlir/include/mlir/IR/FunctionSupport.h
+++ b/mlir/include/mlir/IR/FunctionSupport.h
@@ -428,7 +428,7 @@ LogicalResult FunctionLike<ConcreteType>::verifyTrait(Operation *op) {
if (!attr.first.strref().contains('.'))
return funcOp.emitOpError("arguments may only have dialect attributes");
auto dialectNamePair = attr.first.strref().split('.');
- if (auto *dialect = ctx->getRegisteredDialect(dialectNamePair.first)) {
+ if (auto *dialect = ctx->getLoadedDialect(dialectNamePair.first)) {
if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0,
/*argIndex=*/i, attr)))
return failure();
@@ -444,7 +444,7 @@ LogicalResult FunctionLike<ConcreteType>::verifyTrait(Operation *op) {
if (!attr.first.strref().contains('.'))
return funcOp.emitOpError("results may only have dialect attributes");
auto dialectNamePair = attr.first.strref().split('.');
- if (auto *dialect = ctx->getRegisteredDialect(dialectNamePair.first)) {
+ if (auto *dialect = ctx->getLoadedDialect(dialectNamePair.first)) {
if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0,
/*resultIndex=*/i,
attr)))
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index 0192a8a..d406c30 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -19,10 +19,12 @@ namespace mlir {
class AbstractOperation;
class DiagnosticEngine;
class Dialect;
+class DialectRegistry;
class InFlightDiagnostic;
class Location;
class MLIRContextImpl;
class StorageUniquer;
+DialectRegistry &getGlobalDialectRegistry();
/// MLIRContext is the top-level object for a collection of MLIR modules. It
/// holds immortal uniqued objects like types, and the tables used to unique
@@ -34,34 +36,54 @@ class StorageUniquer;
///
class MLIRContext {
public:
- explicit MLIRContext();
+ /// Create a new Context.
+ /// The loadAllDialects parameters allows to load all dialects from the global
+ /// registry on Context construction. It is deprecated and will be removed
+ /// soon.
+ explicit MLIRContext(bool loadAllDialects = true);
~MLIRContext();
- /// Return information about all registered IR dialects.
- std::vector<Dialect *> getRegisteredDialects();
+ /// Return information about all IR dialects loaded in the context.
+ std::vector<Dialect *> getLoadedDialects();
+
+ /// Return the dialect registry associated with this context.
+ DialectRegistry &getDialectRegistry();
+
+ /// Return information about all available dialects in the registry in this
+ /// context.
+ std::vector<StringRef> getAvailableDialects();
/// Get a registered IR dialect with the given namespace. If an exact match is
/// not found, then return nullptr.
- Dialect *getRegisteredDialect(StringRef name);
+ Dialect *getLoadedDialect(StringRef name);
/// Get a registered IR dialect for the given derived dialect type. The
/// derived type must provide a static 'getDialectNamespace' method.
- template <typename T> T *getRegisteredDialect() {
- return static_cast<T *>(getRegisteredDialect(T::getDialectNamespace()));
+ template <typename T> T *getLoadedDialect() {
+ return static_cast<T *>(getLoadedDialect(T::getDialectNamespace()));
}
/// Get (or create) a dialect for the given derived dialect type. The derived
/// type must provide a static 'getDialectNamespace' method.
- template <typename T>
- T *getOrCreateDialect() {
- return static_cast<T *>(getOrCreateDialect(
- T::getDialectNamespace(), TypeID::get<T>(), [this]() {
+ template <typename T> T *getOrLoadDialect() {
+ return static_cast<T *>(
+ getOrLoadDialect(T::getDialectNamespace(), TypeID::get<T>(), [this]() {
std::unique_ptr<T> dialect(new T(this));
- dialect->dialectID = TypeID::get<T>();
return dialect;
}));
}
+ /// Deprecated: load all globally registered dialects into this context.
+ /// This method will be removed soon, it can be used temporarily as we're
+ /// phasing out the global registry.
+ void loadAllGloballyRegisteredDialects();
+
+ /// Get (or create) a dialect for the given derived dialect name.
+ /// The dialect will be loaded from the registry if no dialect is found.
+ /// If no dialect is loaded for this name and none is available in the
+ /// registry, returns nullptr.
+ Dialect *getOrLoadDialect(StringRef name);
+
/// Return true if we allow to create operation for unregistered dialects.
bool allowsUnregisteredDialects();
@@ -123,10 +145,12 @@ private:
const std::unique_ptr<MLIRContextImpl> impl;
/// Get a dialect for the provided namespace and TypeID: abort the program if
- /// a dialect exist for this namespace with different TypeID. Returns a
- /// pointer to the dialect owned by the context.
- Dialect *getOrCreateDialect(StringRef dialectNamespace, TypeID dialectID,
- function_ref<std::unique_ptr<Dialect>()> ctor);
+ /// a dialect exist for this namespace with different TypeID. If a dialect has
+ /// not been loaded for this namespace/TypeID yet, use the provided ctor to
+ /// create one on the fly and load it. Returns a pointer to the dialect owned
+ /// by the context.
+ Dialect *getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
+ function_ref<std::unique_ptr<Dialect>()> ctor);
MLIRContext(const MLIRContext &) = delete;
void operator=(const MLIRContext &) = delete;
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 9cc57a6..a28410f 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -244,6 +244,11 @@ class Dialect {
// The description of the dialect.
string description = ?;
+ // A list of dialects this dialect will load on construction as dependencies.
+ // These are dialects that this dialect may involved in canonicalization
+ // pattern or interfaces.
+ list<string> dependentDialects = [];
+
// The C++ namespace that ops of this dialect should be placed into.
//
// By default, uses the name of the dialect as the only namespace. To avoid
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index b76b26f..a456616 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -35,29 +35,32 @@
namespace mlir {
+// Add all the MLIR dialects to the provided registry.
+inline void registerAllDialects(DialectRegistry &registry) {
+ registry.insert<acc::OpenACCDialect>();
+ registry.insert<AffineDialect>();
+ registry.insert<avx512::AVX512Dialect>();
+ registry.insert<gpu::GPUDialect>();
+ registry.insert<LLVM::LLVMAVX512Dialect>();
+ registry.insert<LLVM::LLVMDialect>();
+ registry.insert<linalg::LinalgDialect>();
+ registry.insert<scf::SCFDialect>();
+ registry.insert<omp::OpenMPDialect>();
+ registry.insert<quant::QuantizationDialect>();
+ registry.insert<spirv::SPIRVDialect>();
+ registry.insert<StandardOpsDialect>();
+ registry.insert<vector::VectorDialect>();
+ registry.insert<NVVM::NVVMDialect>();
+ registry.insert<ROCDL::ROCDLDialect>();
+ registry.insert<SDBMDialect>();
+ registry.insert<shape::ShapeDialect>();
+}
+
// This function should be called before creating any MLIRContext if one expect
// all the possible dialects to be made available to the context automatically.
inline void registerAllDialects() {
- static bool init_once = []() {
- registerDialect<acc::OpenACCDialect>();
- registerDialect<AffineDialect>();
- registerDialect<avx512::AVX512Dialect>();
- registerDialect<gpu::GPUDialect>();
- registerDialect<LLVM::LLVMAVX512Dialect>();
- registerDialect<LLVM::LLVMDialect>();
- registerDialect<linalg::LinalgDialect>();
- registerDialect<scf::SCFDialect>();
- registerDialect<omp::OpenMPDialect>();
- registerDialect<quant::QuantizationDialect>();
- registerDialect<spirv::SPIRVDialect>();
- registerDialect<StandardOpsDialect>();
- registerDialect<vector::VectorDialect>();
- registerDialect<NVVM::NVVMDialect>();
- registerDialect<ROCDL::ROCDLDialect>();
- registerDialect<SDBMDialect>();
- registerDialect<shape::ShapeDialect>();
- return true;
- }();
+ static bool init_once =
+ ([]() { registerAllDialects(getGlobalDialectRegistry()); }(), true);
(void)init_once;
}
} // namespace mlir
diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
index 7c0f9bd..ea361ae 100644
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -9,6 +9,7 @@
#ifndef MLIR_PASS_PASS_H
#define MLIR_PASS_PASS_H
+#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/Pass/AnalysisManager.h"
#include "mlir/Pass/PassRegistry.h"
@@ -57,6 +58,13 @@ public:
/// Returns the derived pass name.
virtual StringRef getName() const = 0;
+ /// Register dependent dialects for the current pass.
+ /// A pass is expected to register the dialects it will create operations for,
+ /// other than dialect that exists in the input. For example, a pass that
+ /// converts from Linalg to Affine would register the Affine dialect but does
+ /// not need to register Linalg.
+ virtual void getDependentDialects(DialectRegistry &registry) const {}
+
/// Returns the command line argument used when registering this pass. Return
/// an empty string if one does not exist.
virtual StringRef getArgument() const {
diff --git a/mlir/include/mlir/Pass/PassBase.td b/mlir/include/mlir/Pass/PassBase.td
index 54b4403..749d042 100644
--- a/mlir/include/mlir/Pass/PassBase.td
+++ b/mlir/include/mlir/Pass/PassBase.td
@@ -78,6 +78,9 @@ class PassBase<string passArg, string base> {
// A C++ constructor call to create an instance of this pass.
code constructor = [{}];
+ // A list of dialects this pass may produce operations in.
+ list<string> dependentDialects = [];
+
// A set of options provided by this pass.
list<Option> options = [];
diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index 9cbfb0b..29e7c07 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -9,6 +9,7 @@
#ifndef MLIR_PASS_PASSMANAGER_H
#define MLIR_PASS_PASSMANAGER_H
+#include "mlir/IR/Dialect.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/Optional.h"
@@ -58,6 +59,14 @@ public:
pass_iterator end();
iterator_range<pass_iterator> getPasses() { return {begin(), end()}; }
+ using const_pass_iterator = llvm::pointee_iterator<
+ std::vector<std::unique_ptr<Pass>>::const_iterator>;
+ const_pass_iterator begin() const;
+ const_pass_iterator end() const;
+ iterator_range<const_pass_iterator> getPasses() const {
+ return {begin(), end()};
+ }
+
/// Run the held passes over the given operation.
LogicalResult run(Operation *op, AnalysisManager am);
@@ -100,6 +109,11 @@ public:
/// Merge the pass statistics of this class into 'other'.
void mergeStatisticsInto(OpPassManager &other);
+ /// Register dependent dialects for the current pass manager.
+ /// This is forwarding to every pass in this PassManager, see the
+ /// documentation for the same method on the Pass class.
+ void getDependentDialects(DialectRegistry &dialects) const;
+
private:
OpPassManager(OperationName name, bool verifyPasses);
diff --git a/mlir/include/mlir/Support/MlirOptMain.h b/mlir/include/mlir/Support/MlirOptMain.h
index f235ea3..741276b 100644
--- a/mlir/include/mlir/Support/MlirOptMain.h
+++ b/mlir/include/mlir/Support/MlirOptMain.h
@@ -22,10 +22,15 @@ namespace mlir {
struct LogicalResult;
class PassPipelineCLParser;
+/// Run an passPipeline on the provided memory buffer loaded as an MLIRModule.
+/// The preloadDialectsInContext option will trigger an option upfront loading
+/// of all dialects from the global registry in the MLIRContext. This option is
+/// deprecated and will be removed soon.
LogicalResult MlirOptMain(llvm::raw_ostream &os,
std::unique_ptr<llvm::MemoryBuffer> buffer,
const PassPipelineCLParser &passPipeline,
bool splitInputFile, bool verifyDiagnostics,
- bool verifyPasses, bool allowUnregisteredDialects);
+ bool verifyPasses, bool allowUnregisteredDialects,
+ bool preloadDialectsInContext = false);
} // end namespace mlir
diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h
index 5e85806..99217d8 100644
--- a/mlir/include/mlir/TableGen/Dialect.h
+++ b/mlir/include/mlir/TableGen/Dialect.h
@@ -14,6 +14,7 @@
#include "mlir/Support/LLVM.h"
#include <string>
+#include <vector>
namespace llvm {
class Record;
@@ -25,7 +26,7 @@ namespace tblgen {
// and provides helper methods for accessing them.
class Dialect {
public:
- explicit Dialect(const llvm::Record *def) : def(def) {}
+ explicit Dialect(const llvm::Record *def);
// Returns the name of this dialect.
StringRef getName() const;
@@ -43,6 +44,10 @@ public:
// Returns the description of the dialect. Returns empty string if none.
StringRef getDescription() const;
+ // Returns the list of dialect (class names) that this dialect depends on.
+ // These are dialects that will be loaded on construction of this dialect.
+ ArrayRef<StringRef> getDependentDialects() const;
+
// Returns the dialects extra class declaration code.
llvm::Optional<StringRef> getExtraClassDeclaration() const;
@@ -70,6 +75,7 @@ public:
private:
const llvm::Record *def;
+ std::vector<StringRef> dependentDialects;
};
} // end namespace tblgen
} // end namespace mlir
diff --git a/mlir/include/mlir/TableGen/Pass.h b/mlir/include/mlir/TableGen/Pass.h
index 02427e4..968c854 100644
--- a/mlir/include/mlir/TableGen/Pass.h
+++ b/mlir/include/mlir/TableGen/Pass.h
@@ -94,6 +94,9 @@ public:
/// Return the C++ constructor call to create an instance of this pass.
StringRef getConstructor() const;
+ /// Return the dialects this pass needs to be registered.
+ ArrayRef<StringRef> getDependentDialects() const;
+
/// Return the options provided by this pass.
ArrayRef<PassOption> getOptions() const;
@@ -104,6 +107,7 @@ public:
private:
const llvm::Record *def;
+ std::vector<StringRef> dependentDialects;
std::vector<PassOption> options;
std::vector<PassStatistic> statistics;
};
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 7787805..3292d5e 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -162,6 +162,8 @@ def BufferPlacement : FunctionPass<"buffer-placement"> {
}];
let constructor = "mlir::createBufferPlacementPass()";
+ // TODO: this pass likely shouldn't depend on Linalg?
+ let dependentDialects = ["linalg::LinalgDialect"];
}
def Canonicalizer : Pass<"canonicalize"> {
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 4ccfb45..2417c1d 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -9,9 +9,11 @@
#include "mlir-c/IR.h"
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Dialect.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
+#include "mlir/InitAllDialects.h"
#include "mlir/Parser.h"
#include "llvm/Support/raw_ostream.h"
@@ -89,12 +91,17 @@ private:
/* ========================================================================== */
MlirContext mlirContextCreate() {
- auto *context = new MLIRContext;
+ auto *context = new MLIRContext(false);
return wrap(context);
}
void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
+void mlirContextLoadAllDialects(MlirContext context) {
+ registerAllDialects(unwrap(context));
+ getGlobalDialectRegistry().loadAll(unwrap(context));
+}
+
/* ========================================================================== */
/* Location API. */
/* ========================================================================== */
diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
index 1ebf481..4267393 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
@@ -16,6 +16,7 @@
#include "../PassDetail.h"
#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Serialization.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index 7b57854..0460d98 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
index 6da0bc8..7fa5a5a 100644
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -12,11 +12,43 @@
#include "mlir/Pass/Pass.h"
namespace mlir {
+class AffineDialect;
+class StandardOpsDialect;
+
+// Forward declaration from Dialect.h
+template <typename ConcreteDialect>
+void registerDialect(DialectRegistry &registry);
namespace gpu {
+class GPUDialect;
class GPUModuleOp;
} // end namespace gpu
+namespace LLVM {
+class LLVMDialect;
+class LLVMAVX512Dialect;
+} // end namespace LLVM
+
+namespace NVVM {
+class NVVMDialect;
+} // end namespace NVVM
+
+namespace ROCDL {
+class ROCDLDialect;
+} // end namespace ROCDL
+
+namespace scf {
+class SCFDialect;
+} // end namespace scf
+
+namespace spirv {
+class SPIRVDialect;
+} // end namespace spirv
+
+namespace vector {
+class VectorDialect;
+} // end namespace vector
+
#define GEN_PASS_CLASSES
#include "mlir/Conversion/Passes.h.inc"
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index efe4a3c..6096703 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -125,7 +125,7 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
/// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
const LowerToLLVMOptions &options)
- : llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()),
+ : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()),
options(options) {
assert(llvmDialect && "LLVM IR dialect is not registered");
if (options.indexBitwidth == kDeriveIndexBitwidthFromDataLayout)
diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
index 19643d2..a2e608d 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
@@ -14,6 +14,7 @@
#include "../PassDetail.h"
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
diff --git a/mlir/lib/Dialect/Affine/Transforms/PassDetail.h b/mlir/lib/Dialect/Affine/Transforms/PassDetail.h
index 3bae059..da8f7ac 100644
--- a/mlir/lib/Dialect/Affine/Transforms/PassDetail.h
+++ b/mlir/lib/Dialect/Affine/Transforms/PassDetail.h
@@ -12,6 +12,16 @@
#include "mlir/Pass/Pass.h"
namespace mlir {
+// Forward declaration from Dialect.h
+template <typename ConcreteDialect>
+void registerDialect(DialectRegistry &registry);
+
+namespace linalg {
+class LinalgDialect;
+} // end namespace linalg
+namespace vector {
+class VectorDialect;
+} // end namespace vector
#define GEN_PASS_CLASSES
#include "mlir/Dialect/Affine/Passes.h.inc"
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 009699b..bf18c6c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1224,6 +1224,7 @@ template <typename NamedStructuredOpType>
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::OperandType, 8> operandsInfo;
+ result.getContext()->getOrLoadDialect<StandardOpsDialect>();
// Optional attributes may be added.
if (parser.parseOperandList(operandsInfo) ||
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h b/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h
index 7fa05ff..0415aeb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h
+++ b/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h
@@ -9,9 +9,18 @@
#ifndef DIALECT_LINALG_TRANSFORMS_PASSDETAIL_H_
#define DIALECT_LINALG_TRANSFORMS_PASSDETAIL_H_
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
+// Forward declaration from Dialect.h
+template <typename ConcreteDialect>
+void registerDialect(DialectRegistry &registry);
+
+namespace scf {
+class SCFDialect;
+} // end namespace scf
#define GEN_PASS_CLASSES
#include "mlir/Dialect/Linalg/Passes.h.inc"
diff --git a/mlir/lib/Dialect/SCF/Transforms/PassDetail.h b/mlir/lib/Dialect/SCF/Transforms/PassDetail.h
index 95f8636..6fa7f22 100644
--- a/mlir/lib/Dialect/SCF/Transforms/PassDetail.h
+++ b/mlir/lib/Dialect/SCF/Transforms/PassDetail.h
@@ -12,6 +12,11 @@
#include "mlir/Pass/Pass.h"
namespace mlir {
+// Forward declaration from Dialect.h
+template <typename ConcreteDialect>
+void registerDialect(DialectRegistry &registry);
+
+class AffineDialect;
#define GEN_PASS_CLASSES
#include "mlir/Dialect/SCF/Passes.h.inc"
diff --git a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp
index 435c7fe..a1971c3 100644
--- a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp
+++ b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp
@@ -517,7 +517,7 @@ Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
SDBMDialect *dialect;
} converter;
- converter.dialect = affine.getContext()->getRegisteredDialect<SDBMDialect>();
+ converter.dialect = affine.getContext()->getOrLoadDialect<SDBMDialect>();
if (auto result = converter.visit(affine))
return result;
diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp
index 7959183..2b18adb 100644
--- a/mlir/lib/ExecutionEngine/JitRunner.cpp
+++ b/mlir/lib/ExecutionEngine/JitRunner.cpp
@@ -259,7 +259,9 @@ int mlir::JitRunnerMain(
}
}
- MLIRContext context;
+ MLIRContext context(/*loadAllDialects=*/false);
+ registerAllDialects(&context);
+
auto m = parseMLIRInput(options.inputFilename, &context);
if (!m) {
llvm::errs() << "could not parse the input IR\n";
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index 555bb2b..f2f0a63 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -27,21 +27,25 @@ DialectAsmParser::~DialectAsmParser() {}
//===----------------------------------------------------------------------===//
/// Registry for all dialect allocation functions.
-static llvm::ManagedStatic<llvm::MapVector<TypeID, DialectAllocatorFunction>>
- dialectRegistry;
-
-void Dialect::registerDialectAllocator(
- TypeID typeID, const DialectAllocatorFunction &function) {
- assert(function &&
- "Attempting to register an empty dialect initialize function");
- dialectRegistry->insert({typeID, function});
+static llvm::ManagedStatic<DialectRegistry> dialectRegistry;
+DialectRegistry &mlir::getGlobalDialectRegistry() { return *dialectRegistry; }
+
+Dialect *DialectRegistry::loadByName(StringRef name, MLIRContext *context) {
+ auto it = registry.find(std::string(name));
+ if (it == registry.end())
+ return nullptr;
+ return it->second.second(context);
}
-/// Registers all dialects and hooks from the global registries with the
-/// specified MLIRContext.
-void mlir::registerAllDialects(MLIRContext *context) {
- for (const auto &it : *dialectRegistry)
- it.second(context);
+void DialectRegistry::insert(TypeID typeID, StringRef name,
+ DialectAllocatorFunction ctor) {
+ auto inserted =
+ registry.insert(std::make_pair(name, std::make_pair(typeID, ctor)));
+ if (!inserted.second && inserted.first->second.first != typeID) {
+ llvm::report_fatal_error(
+ "Trying to register different dialects for the same namespace: " +
+ name);
+ }
}
//===----------------------------------------------------------------------===//
@@ -119,7 +123,7 @@ DialectInterface::~DialectInterface() {}
DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
MLIRContext *ctx, TypeID interfaceKind) {
- for (auto *dialect : ctx->getRegisteredDialects()) {
+ for (auto *dialect : ctx->getLoadedDialects()) {
if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
interfaces.insert(interface);
orderedInterfaces.push_back(interface);
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 0d66070..ed8e3db 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -31,10 +31,13 @@
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
#include "llvm/Support/RWMutex.h"
#include "llvm/Support/raw_ostream.h"
#include <memory>
+#define DEBUG_TYPE "mlircontext"
+
using namespace mlir;
using namespace mlir::detail;
@@ -275,7 +278,8 @@ public:
/// This is a list of dialects that are created referring to this context.
/// The MLIRContext owns the objects.
- std::vector<std::unique_ptr<Dialect>> dialects;
+ DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects;
+ DialectRegistry dialectsRegistry;
/// This is a mapping from operation name to AbstractOperation for registered
/// operations.
@@ -346,7 +350,7 @@ public:
};
} // end namespace mlir
-MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
+MLIRContext::MLIRContext(bool loadAllDialects) : impl(new MLIRContextImpl()) {
// Initialize values based on the command line flags if they were provided.
if (clOptions.isConstructed()) {
disableMultithreading(clOptions->disableThreading);
@@ -355,8 +359,9 @@ MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
}
// Register dialects with this context.
- getOrCreateDialect<BuiltinDialect>();
- registerAllDialects(this);
+ getOrLoadDialect<BuiltinDialect>();
+ if (loadAllDialects)
+ loadAllGloballyRegisteredDialects();
// Initialize several common attributes and types to avoid the need to lock
// the context when accessing them.
@@ -438,54 +443,72 @@ DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
// Dialect and Operation Registration
//===----------------------------------------------------------------------===//
+DialectRegistry &MLIRContext::getDialectRegistry() {
+ return impl->dialectsRegistry;
+}
+
/// Return information about all registered IR dialects.
-std::vector<Dialect *> MLIRContext::getRegisteredDialects() {
+std::vector<Dialect *> MLIRContext::getLoadedDialects() {
std::vector<Dialect *> result;
- result.reserve(impl->dialects.size());
- for (auto &dialect : impl->dialects)
- result.push_back(dialect.get());
+ result.reserve(impl->loadedDialects.size());
+ for (auto &dialect : impl->loadedDialects) {
+ result.push_back(dialect.second.get());
+ }
+ llvm::sort(result, [](Dialect *lhs, Dialect *rhs) {
+ return lhs->getNamespace() < rhs->getNamespace();
+ });
+ return result;
+}
+std::vector<StringRef> MLIRContext::getAvailableDialects() {
+ std::vector<StringRef> result;
+ for (auto &dialect : impl->dialectsRegistry)
+ result.push_back(dialect.first);
return result;
}
/// Get a registered IR dialect with the given namespace. If none is found,
/// then return nullptr.
-Dialect *MLIRContext::getRegisteredDialect(StringRef name) {
+Dialect *MLIRContext::getLoadedDialect(StringRef name) {
// Dialects are sorted by name, so we can use binary search for lookup.
- auto it = llvm::lower_bound(
- impl->dialects, name,
- [](const auto &lhs, StringRef rhs) { return lhs->getNamespace() < rhs; });
- return (it != impl->dialects.end() && (*it)->getNamespace() == name)
- ? (*it).get()
- : nullptr;
+ auto it = impl->loadedDialects.find(name);
+ return (it != impl->loadedDialects.end()) ? it->second.get() : nullptr;
+}
+
+Dialect *MLIRContext::getOrLoadDialect(StringRef name) {
+ Dialect *dialect = getLoadedDialect(name);
+ if (dialect)
+ return dialect;
+ return impl->dialectsRegistry.loadByName(name, this);
}
/// Get a dialect for the provided namespace and TypeID: abort the program if a
/// dialect exist for this namespace with different TypeID. Returns a pointer to
/// the dialect owned by the context.
Dialect *
-MLIRContext::getOrCreateDialect(StringRef dialectNamespace, TypeID dialectID,
- function_ref<std::unique_ptr<Dialect>()> ctor) {
+MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
+ function_ref<std::unique_ptr<Dialect>()> ctor) {
auto &impl = getImpl();
// Get the correct insertion position sorted by namespace.
- auto insertPt =
- llvm::lower_bound(impl.dialects, nullptr,
- [&](const std::unique_ptr<Dialect> &lhs,
- const std::unique_ptr<Dialect> &rhs) {
- if (!lhs)
- return dialectNamespace < rhs->getNamespace();
- return lhs->getNamespace() < dialectNamespace;
- });
+ std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace];
+
+ if (!dialect) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Load new dialect in Context" << dialectNamespace);
+ dialect = ctor();
+ assert(dialect && "dialect ctor failed");
+ return dialect.get();
+ }
// Abort if dialect with namespace has already been registered.
- if (insertPt != impl.dialects.end() &&
- (*insertPt)->getNamespace() == dialectNamespace) {
- if ((*insertPt)->getTypeID() == dialectID)
- return insertPt->get();
+ if (dialect->getTypeID() != dialectID)
llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
"' has already been registered");
- }
- auto it = impl.dialects.insert(insertPt, ctor());
- return &**it;
+
+ return dialect.get();
+}
+
+void MLIRContext::loadAllGloballyRegisteredDialects() {
+ getGlobalDialectRegistry().loadAll(this);
}
bool MLIRContext::allowsUnregisteredDialects() {
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 152ed01..dce570a 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -214,7 +214,7 @@ Dialect *Operation::getDialect() {
// If this operation hasn't been registered or doesn't have abstract
// operation, try looking up the dialect name in the context.
- return getContext()->getRegisteredDialect(getName().getDialect());
+ return getContext()->getLoadedDialect(getName().getDialect());
}
Region *Operation::getParentRegion() {
diff --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp
index b1aed88..4caf989 100644
--- a/mlir/lib/IR/Verifier.cpp
+++ b/mlir/lib/IR/Verifier.cpp
@@ -50,7 +50,7 @@ public:
Dialect *getDialectForAttribute(const NamedAttribute &attr) {
assert(attr.first.strref().contains('.') && "expected dialect attribute");
auto dialectNamePair = attr.first.strref().split('.');
- return ctx->getRegisteredDialect(dialectNamePair.first);
+ return ctx->getLoadedDialect(dialectNamePair.first);
}
private:
@@ -218,7 +218,7 @@ LogicalResult OperationVerifier::verifyOperation(Operation &op) {
auto it = dialectAllowsUnknownOps.find(dialectPrefix);
if (it == dialectAllowsUnknownOps.end()) {
// If the operation dialect is registered, query it directly.
- if (auto *dialect = ctx->getRegisteredDialect(dialectPrefix))
+ if (auto *dialect = ctx->getLoadedDialect(dialectPrefix))
it = dialectAllowsUnknownOps
.try_emplace(dialectPrefix, dialect->allowsUnknownOperations())
.first;
diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index 1c1261e..37ee938 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -12,6 +12,7 @@
#include "Parser.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Dialect.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/ADT/StringExtras.h"
@@ -246,6 +247,11 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
return emitError("duplicate key in dictionary attribute");
consumeToken();
+ // Lazy load a dialect in the context if there is a possible namespace.
+ auto splitName = nameId->strref().split('.');
+ if (!splitName.second.empty())
+ getContext()->getOrLoadDialect(splitName.first);
+
// Try to parse the '=' for the attribute value.
if (!consumeIf(Token::equal)) {
// If there is no '=', we treat this as a unit attribute.
@@ -817,7 +823,9 @@ Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
return (emitError("expected dialect namespace"), nullptr);
auto name = getToken().getStringValue();
- auto *dialect = builder.getContext()->getRegisteredDialect(name);
+ // Lazy load a dialect in the context if there is a possible namespace.
+ Dialect *dialect = builder.getContext()->getOrLoadDialect(name);
+
// TODO: Allow for having an unknown dialect on an opaque
// attribute. Otherwise, it can't be roundtripped without having the dialect
// registered.
diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp
index 3b522a8..d45ddf0 100644
--- a/mlir/lib/Parser/DialectSymbolParser.cpp
+++ b/mlir/lib/Parser/DialectSymbolParser.cpp
@@ -526,7 +526,8 @@ Attribute Parser::parseExtendedAttr(Type type) {
return Attribute();
// If we found a registered dialect, then ask it to parse the attribute.
- if (auto *dialect = state.context->getRegisteredDialect(dialectName)) {
+ if (Dialect *dialect =
+ builder.getContext()->getOrLoadDialect(dialectName)) {
return parseSymbol<Attribute>(
symbolData, state.context, state.symbols, [&](Parser &parser) {
CustomDialectAsmParser customParser(symbolData, parser);
@@ -563,7 +564,9 @@ Type Parser::parseExtendedType() {
[&](StringRef dialectName, StringRef symbolData,
llvm::SMLoc loc) -> Type {
// If we found a registered dialect, then ask it to parse the type.
- if (auto *dialect = state.context->getRegisteredDialect(dialectName)) {
+ auto *dialect = state.context->getOrLoadDialect(dialectName);
+
+ if (dialect) {
return parseSymbol<Type>(
symbolData, state.context, state.symbols, [&](Parser &parser) {
CustomDialectAsmParser customParser(symbolData, parser);
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 3a995a4..837b08c 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -12,6 +12,7 @@
#include "Parser.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Dialect.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser.h"
@@ -727,7 +728,7 @@ Operation *OperationParser::parseGenericOperation() {
// Get location information for the operation.
auto srcLocation = getEncodedSourceLocation(getToken().getLoc());
- auto name = getToken().getStringValue();
+ std::string name = getToken().getStringValue();
if (name.empty())
return (emitError("empty operation name is invalid"), nullptr);
if (name.find('\0') != StringRef::npos)
@@ -737,6 +738,15 @@ Operation *OperationParser::parseGenericOperation() {
OperationState result(srcLocation, name);
+ // Lazy load dialects in the context as needed.
+ if (!result.name.getAbstractOperation()) {
+ StringRef dialectName = StringRef(name).split('.').first;
+ if (!getContext()->getLoadedDialect(dialectName) &&
+ getContext()->getOrLoadDialect(dialectName)) {
+ result.name = OperationName(name, getContext());
+ }
+ }
+
// Parse the operand list.
SmallVector<SSAUseInfo, 8> operandInfos;
if (parseToken(Token::l_paren, "expected '(' to start operand list") ||
@@ -1442,17 +1452,28 @@ private:
Operation *
OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
- auto opLoc = getToken().getLoc();
- auto opName = getTokenSpelling();
+ llvm::SMLoc opLoc = getToken().getLoc();
+ StringRef opName = getTokenSpelling();
auto *opDefinition = AbstractOperation::lookup(opName, getContext());
- if (!opDefinition && !opName.contains('.')) {
- // If the operation name has no namespace prefix we treat it as a standard
- // operation and prefix it with "std".
- // TODO: Would it be better to just build a mapping of the registered
- // operations in the standard dialect?
- opDefinition =
- AbstractOperation::lookup(Twine("std." + opName).str(), getContext());
+ if (!opDefinition) {
+ if (opName.contains('.')) {
+ // This op has a dialect, we try to check if we can register it in the
+ // context on the fly.
+ StringRef dialectName = opName.split('.').first;
+ if (!getContext()->getLoadedDialect(dialectName) &&
+ getContext()->getOrLoadDialect(dialectName)) {
+ opDefinition = AbstractOperation::lookup(opName, getContext());
+ }
+ } else {
+ // If the operation name has no namespace prefix we treat it as a standard
+ // operation and prefix it with "std".
+ // TODO: Would it be better to just build a mapping of the registered
+ // operations in the standard dialect?
+ if (getContext()->getOrLoadDialect("std"))
+ opDefinition = AbstractOperation::lookup(Twine("std." + opName).str(),
+ getContext());
+ }
}
if (!opDefinition) {
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index b791bf4..9debd11 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -290,6 +290,13 @@ OpPassManager::pass_iterator OpPassManager::begin() {
}
OpPassManager::pass_iterator OpPassManager::end() { return impl->passes.end(); }
+OpPassManager::const_pass_iterator OpPassManager::begin() const {
+ return impl->passes.begin();
+}
+OpPassManager::const_pass_iterator OpPassManager::end() const {
+ return impl->passes.end();
+}
+
/// Run all of the passes in this manager over the current operation.
LogicalResult OpPassManager::run(Operation *op, AnalysisManager am) {
// Run each of the held passes.
@@ -346,6 +353,16 @@ void OpPassManager::printAsTextualPipeline(raw_ostream &os) {
::printAsTextualPipeline(impl->passes, os);
}
+static void registerDialectsForPipeline(const OpPassManager &pm,
+ DialectRegistry &dialects) {
+ for (const Pass &pass : pm.getPasses())
+ pass.getDependentDialects(dialects);
+}
+
+void OpPassManager::getDependentDialects(DialectRegistry &dialects) const {
+ registerDialectsForPipeline(*this, dialects);
+}
+
//===----------------------------------------------------------------------===//
// OpToOpPassAdaptor
//===----------------------------------------------------------------------===//
@@ -378,6 +395,11 @@ OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr) {
mgrs.emplace_back(std::move(mgr));
}
+void OpToOpPassAdaptor::getDependentDialects(DialectRegistry &dialects) const {
+ for (auto &pm : mgrs)
+ pm.getDependentDialects(dialects);
+}
+
/// Merge the current pass adaptor into given 'rhs'.
void OpToOpPassAdaptor::mergeInto(OpToOpPassAdaptor &rhs) {
for (auto &pm : mgrs) {
@@ -721,6 +743,11 @@ LogicalResult PassManager::run(ModuleOp module) {
// pipeline.
getImpl().coalesceAdjacentAdaptorPasses();
+ // Register all dialects for the current pipeline.
+ DialectRegistry dependent_dialects;
+ getDependentDialects(dependent_dialects);
+ dependent_dialects.loadAll(module.getContext());
+
// Construct an analysis manager for the pipeline.
ModuleAnalysisManager am(module, instrumentor.get());
diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h
index 2342a1a..f69701d 100644
--- a/mlir/lib/Pass/PassDetail.h
+++ b/mlir/lib/Pass/PassDetail.h
@@ -43,6 +43,10 @@ public:
/// Returns the pass managers held by this adaptor.
MutableArrayRef<OpPassManager> getPassManagers() { return mgrs; }
+ /// Populate the set of dependent dialects for the passes in the current
+ /// adaptor.
+ void getDependentDialects(DialectRegistry &dialects) const override;
+
/// Return the async pass managers held by this parallel adaptor.
MutableArrayRef<SmallVector<OpPassManager, 1>> getParallelPassManagers() {
return asyncExecutors;
diff --git a/mlir/lib/Support/MlirOptMain.cpp b/mlir/lib/Support/MlirOptMain.cpp
index 25e1970..e450fb3 100644
--- a/mlir/lib/Support/MlirOptMain.cpp
+++ b/mlir/lib/Support/MlirOptMain.cpp
@@ -75,13 +75,17 @@ static LogicalResult processBuffer(raw_ostream &os,
std::unique_ptr<MemoryBuffer> ownedBuffer,
bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects,
+ bool preloadDialectsInContext,
const PassPipelineCLParser &passPipeline) {
// Tell sourceMgr about this buffer, which is what the parser will pick up.
SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
// Parse the input file.
- MLIRContext context;
+ MLIRContext context(/*loadAllDialects=*/false);
+ registerAllDialects(&context);
+ if (preloadDialectsInContext)
+ context.getDialectRegistry().loadAll(&context);
context.allowUnregisteredDialects(allowUnregisteredDialects);
context.printOpOnDiagnostic(!verifyDiagnostics);
@@ -111,7 +115,8 @@ LogicalResult mlir::MlirOptMain(raw_ostream &os,
const PassPipelineCLParser &passPipeline,
bool splitInputFile, bool verifyDiagnostics,
bool verifyPasses,
- bool allowUnregisteredDialects) {
+ bool allowUnregisteredDialects,
+ bool preloadDialectsInContext) {
// The split-input-file mode is a very specific mode that slices the file
// up into small pieces and checks each independently.
if (splitInputFile)
@@ -120,10 +125,11 @@ LogicalResult mlir::MlirOptMain(raw_ostream &os,
[&](std::unique_ptr<MemoryBuffer> chunkBuffer, raw_ostream &os) {
return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics,
verifyPasses, allowUnregisteredDialects,
- passPipeline);
+ preloadDialectsInContext, passPipeline);
},
os);
return processBuffer(os, std::move(buffer), verifyDiagnostics, verifyPasses,
- allowUnregisteredDialects, passPipeline);
+ allowUnregisteredDialects, preloadDialectsInContext,
+ passPipeline);
}
diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp
index 6af77e7..8aee067 100644
--- a/mlir/lib/TableGen/Dialect.cpp
+++ b/mlir/lib/TableGen/Dialect.cpp
@@ -16,6 +16,11 @@
using namespace mlir;
using namespace mlir::tblgen;
+Dialect::Dialect(const llvm::Record *def) : def(def) {
+ for (StringRef dialect : def->getValueAsListOfStrings("dependentDialects"))
+ dependentDialects.push_back(dialect);
+}
+
StringRef Dialect::getName() const { return def->getValueAsString("name"); }
StringRef Dialect::getCppNamespace() const {
@@ -46,6 +51,10 @@ StringRef Dialect::getDescription() const {
return getAsStringOrEmpty(*def, "description");
}
+ArrayRef<StringRef> Dialect::getDependentDialects() const {
+ return dependentDialects;
+}
+
llvm::Optional<StringRef> Dialect::getExtraClassDeclaration() const {
auto value = def->getValueAsString("extraClassDeclaration");
return value.empty() ? llvm::Optional<StringRef>() : value;
diff --git a/mlir/lib/TableGen/Pass.cpp b/mlir/lib/TableGen/Pass.cpp
index 4bc46b6..f961806 100644
--- a/mlir/lib/TableGen/Pass.cpp
+++ b/mlir/lib/TableGen/Pass.cpp
@@ -69,6 +69,8 @@ Pass::Pass(const llvm::Record *def) : def(def) {
options.push_back(PassOption(init));
for (auto *init : def->getValueAsListOfDefs("statistics"))
statistics.push_back(PassStatistic(init));
+ for (StringRef dialect : def->getValueAsListOfStrings("dependentDialects"))
+ dependentDialects.push_back(dialect);
}
StringRef Pass::getArgument() const {
@@ -88,6 +90,9 @@ StringRef Pass::getDescription() const {
StringRef Pass::getConstructor() const {
return def->getValueAsString("constructor");
}
+ArrayRef<StringRef> Pass::getDependentDialects() const {
+ return dependentDialects;
+}
ArrayRef<PassOption> Pass::getOptions() const { return options; }
diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
index 470044b..1d01569 100644
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -836,6 +836,7 @@ LogicalResult Importer::processBasicBlock(llvm::BasicBlock *bb, Block *block) {
OwningModuleRef
mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
MLIRContext *context) {
+ context->getOrLoadDialect<LLVMDialect>();
OwningModuleRef module(ModuleOp::create(
FileLineColLoc::get("", /*line=*/0, /*column=*/0, context)));
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 215c191..027422b 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -302,8 +302,7 @@ ModuleTranslation::ModuleTranslation(Operation *module,
: mlirModule(module), llvmModule(std::move(llvmModule)),
debugTranslation(
std::make_unique<DebugTranslation>(module, *this->llvmModule)),
- ompDialect(
- module->getContext()->getRegisteredDialect<omp::OpenMPDialect>()),
+ ompDialect(module->getContext()->getOrLoadDialect<omp::OpenMPDialect>()),
typeTranslator(this->llvmModule->getContext()) {
assert(satisfiesLLVMModule(mlirModule) &&
"mlirModule should honor LLVM's module semantics.");
@@ -944,7 +943,7 @@ ModuleTranslation::lookupValues(ValueRange values) {
std::unique_ptr<llvm::Module> ModuleTranslation::prepareLLVMModule(
Operation *m, llvm::LLVMContext &llvmContext, StringRef name) {
- auto *dialect = m->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
+ auto *dialect = m->getContext()->getOrLoadDialect<LLVM::LLVMDialect>();
assert(dialect && "LLVM dialect must be registered");
auto llvmModule = std::make_unique<llvm::Module>(name, llvmContext);
diff --git a/mlir/lib/Transforms/PassDetail.h b/mlir/lib/Transforms/PassDetail.h
index c6f7e22..220ed1a 100644
--- a/mlir/lib/Transforms/PassDetail.h
+++ b/mlir/lib/Transforms/PassDetail.h
@@ -12,6 +12,13 @@
#include "mlir/Pass/Pass.h"
namespace mlir {
+// Forward declaration from Dialect.h
+template <typename ConcreteDialect>
+void registerDialect(DialectRegistry &registry);
+
+namespace linalg {
+class LinalgDialect;
+} // end namespace linalg
#define GEN_PASS_CLASSES
#include "mlir/Transforms/Passes.h.inc"
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index d6ab351..df2d32f 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -243,6 +243,7 @@ static void printFirstOfEach(MlirOperation operation) {
int main() {
mlirRegisterAllDialects();
MlirContext ctx = mlirContextCreate();
+ mlirContextLoadAllDialects(ctx);
MlirLocation location = mlirLocationUnknownGet(ctx);
MlirModule moduleOp = makeAdd(ctx, location);
diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
index 3fcfcf2..e5766f0 100644
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -36,16 +36,16 @@ using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
static MLIRContext &globalContext() {
- static bool init_once = []() {
- registerDialect<AffineDialect>();
- registerDialect<linalg::LinalgDialect>();
- registerDialect<scf::SCFDialect>();
- registerDialect<StandardOpsDialect>();
- registerDialect<vector::VectorDialect>();
+ static thread_local MLIRContext context(/*loadAllDialects=*/false);
+ static thread_local bool init_once = [&]() {
+ context.getOrLoadDialect<AffineDialect>();
+ context.getOrLoadDialect<scf::SCFDialect>();
+ context.getOrLoadDialect<linalg::LinalgDialect>();
+ context.getOrLoadDialect<StandardOpsDialect>();
+ context.getOrLoadDialect<vector::VectorDialect>();
return true;
}();
(void)init_once;
- static thread_local MLIRContext context;
context.allowUnregisteredDialects();
return context;
}
diff --git a/mlir/test/SDBM/sdbm-api-test.cpp b/mlir/test/SDBM/sdbm-api-test.cpp
index 0b58e29..ddefc52 100644
--- a/mlir/test/SDBM/sdbm-api-test.cpp
+++ b/mlir/test/SDBM/sdbm-api-test.cpp
@@ -19,18 +19,19 @@
using namespace mlir;
-// Load the SDBM dialect
-static DialectRegistration<SDBMDialect> SDBMRegistration;
static MLIRContext *ctx() {
- static thread_local MLIRContext context;
+ static thread_local MLIRContext context(/*loadAllDialects=*/false);
+ static thread_local bool once =
+ (context.getOrLoadDialect<SDBMDialect>(), true);
+ (void)once;
return &context;
}
static SDBMDialect *dialect() {
static thread_local SDBMDialect *d = nullptr;
if (!d) {
- d = ctx()->getRegisteredDialect<SDBMDialect>();
+ d = ctx()->getOrLoadDialect<SDBMDialect>();
}
return d;
}
diff --git a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
index a6719b0..cfac2dc 100644
--- a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
@@ -14,6 +14,7 @@
#include "mlir/Analysis/NestedMatcher.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"
@@ -72,6 +73,9 @@ struct VectorizerTestPass
: public PassWrapper<VectorizerTestPass, FunctionPass> {
static constexpr auto kTestAffineMapOpName = "test_affine_map";
static constexpr auto kTestAffineMapAttrName = "affine_map";
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<vector::VectorDialect>();
+ }
void runOnFunction() override;
void testVectorShapeRatio(llvm::raw_ostream &outs);
diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
index 0c1069f..03c425d 100644
--- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
+++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
@@ -30,7 +30,7 @@ void PrintOpAvailability::runOnFunction() {
auto f = getFunction();
llvm::outs() << f.getName() << "\n";
- Dialect *spvDialect = getContext().getRegisteredDialect("spv");
+ Dialect *spvDialect = getContext().getLoadedDialect("spv");
f.getOperation()->walk([&](Operation *op) {
if (op->getDialect() != spvDialect)
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index f2a17a9..be5d799 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -768,6 +768,10 @@ struct TestTypeConversionProducer
struct TestTypeConversionDriver
: public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> {
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<TestDialect>();
+ }
+
void runOnOperation() override {
// Initialize the type converter.
TypeConverter converter;
diff --git a/mlir/test/lib/Transforms/TestAllReduceLowering.cpp b/mlir/test/lib/Transforms/TestAllReduceLowering.cpp
index c043d0f..0c72b6c 100644
--- a/mlir/test/lib/Transforms/TestAllReduceLowering.cpp
+++ b/mlir/test/lib/Transforms/TestAllReduceLowering.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/GPU/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
@@ -19,6 +20,9 @@ using namespace mlir;
namespace {
struct TestAllReduceLoweringPass
: public PassWrapper<TestAllReduceLoweringPass, OperationPass<ModuleOp>> {
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<StandardOpsDialect>();
+ }
void runOnOperation() override {
OwningRewritePatternList patterns;
populateGpuRewritePatterns(&getContext(), patterns);
diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
index 5ad441aa..6cc0924 100644
--- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp
+++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
@@ -116,6 +116,10 @@ struct TestBufferPlacementPreparationPass
patterns->insert<GenericOpConverter>(context, placer, converter);
}
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<linalg::LinalgDialect>();
+ }
+
void runOnOperation() override {
MLIRContext &context = this->getContext();
ConversionTarget target(context);
diff --git a/mlir/test/lib/Transforms/TestGpuMemoryPromotion.cpp b/mlir/test/lib/Transforms/TestGpuMemoryPromotion.cpp
index 08862dd..339ec1a 100644
--- a/mlir/test/lib/Transforms/TestGpuMemoryPromotion.cpp
+++ b/mlir/test/lib/Transforms/TestGpuMemoryPromotion.cpp
@@ -13,6 +13,9 @@
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/MemoryPromotion.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/Pass/Pass.h"
@@ -26,6 +29,11 @@ namespace {
class TestGpuMemoryPromotionPass
: public PassWrapper<TestGpuMemoryPromotionPass,
OperationPass<gpu::GPUFuncOp>> {
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<StandardOpsDialect>();
+ registry.insert<scf::SCFDialect>();
+ }
+
void runOnOperation() override {
gpu::GPUFuncOp op = getOperation();
for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
diff --git a/mlir/test/lib/Transforms/TestLinalgHoisting.cpp b/mlir/test/lib/Transforms/TestLinalgHoisting.cpp
index d1e478f..5d4031f 100644
--- a/mlir/test/lib/Transforms/TestLinalgHoisting.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgHoisting.cpp
@@ -10,6 +10,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Pass/Pass.h"
@@ -22,6 +23,9 @@ struct TestLinalgHoisting
: public PassWrapper<TestLinalgHoisting, FunctionPass> {
TestLinalgHoisting() = default;
TestLinalgHoisting(const TestLinalgHoisting &pass) {}
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<AffineDialect>();
+ }
void runOnFunction() override;
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index f6c1160..b6f3cb4 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
@@ -30,6 +31,14 @@ struct TestLinalgTransforms
TestLinalgTransforms() = default;
TestLinalgTransforms(const TestLinalgTransforms &pass) {}
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<AffineDialect>();
+ registry.insert<scf::SCFDialect>();
+ registry.insert<StandardOpsDialect>();
+ registry.insert<vector::VectorDialect>();
+ registry.insert<gpu::GPUDialect>();
+ }
+
void runOnFunction() override;
Option<bool> testPatterns{*this, "test-patterns",
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 9da3156..24e7a8c 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -8,6 +8,9 @@
#include <type_traits>
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
@@ -128,6 +131,13 @@ struct TestVectorTransferFullPartialSplitPatterns
TestVectorTransferFullPartialSplitPatterns() = default;
TestVectorTransferFullPartialSplitPatterns(
const TestVectorTransferFullPartialSplitPatterns &pass) {}
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<AffineDialect>();
+ registry.insert<linalg::LinalgDialect>();
+ registry.insert<scf::SCFDialect>();
+ }
+
Option<bool> useLinalgOps{
*this, "use-linalg-copy",
llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
diff --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir
index f99a68d..4cf6ea9 100644
--- a/mlir/test/mlir-opt/commandline.mlir
+++ b/mlir/test/mlir-opt/commandline.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt --show-dialects | FileCheck %s
-// CHECK: Registered Dialects:
+// CHECK: Available Dialects:
// CHECK: affine
// CHECK: gpu
// CHECK: linalg
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index 12e6aee..92efef6 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -1703,7 +1703,7 @@ int main(int argc, char **argv) {
if (testEmitIncludeTdHeader)
output->os() << "include \"mlir/Dialect/Linalg/IR/LinalgStructuredOps.td\"";
- MLIRContext context;
+ MLIRContext context(/*loadAllDialects=*/false);
llvm::SourceMgr mgr;
mgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
Parser parser(mgr, &context);
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index efcb328..53ea4da 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -175,11 +175,10 @@ int main(int argc, char **argv) {
cl::ParseCommandLineOptions(argc, argv, "MLIR modular optimizer driver\n");
if(showDialects) {
- llvm::outs() << "Registered Dialects:\n";
- MLIRContext context;
- for(Dialect *dialect : context.getRegisteredDialects()) {
- llvm::outs() << dialect->getNamespace() << "\n";
- }
+ MLIRContext context(false);
+ registerAllDialects(&context);
+ llvm::outs() << "Available Dialects:\n";
+ interleave(context.getAvailableDialects(), llvm::outs(), "\n");
return 0;
}
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index 13421c4..797ecd7 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -61,11 +61,14 @@ filterForDialect(ArrayRef<llvm::Record *> records, Dialect &dialect) {
///
/// {0}: The name of the dialect class.
/// {1}: The dialect namespace.
+/// {2}: initialization code that is emitted in the ctor body before calling
+/// initialize()
static const char *const dialectDeclBeginStr = R"(
class {0} : public ::mlir::Dialect {
explicit {0}(::mlir::MLIRContext *context)
: ::mlir::Dialect(getDialectNamespace(), context,
::mlir::TypeID::get<{0}>()) {{
+ {2}
initialize();
}
void initialize();
@@ -74,6 +77,12 @@ public:
static ::llvm::StringRef getDialectNamespace() { return "{1}"; }
)";
+/// Registration for a single dependent dialect: to be inserted in the ctor
+/// above for each dependent dialect.
+const char *const dialectRegistrationTemplate = R"(
+ getContext()->getOrLoadDialect<{0}>();
+)";
+
/// The code block for the attribute parser/printer hooks.
static const char *const attrParserDecl = R"(
/// Parse an attribute registered to this dialect.
@@ -136,9 +145,18 @@ static void emitDialectDecl(Dialect &dialect,
iterator_range<DialectFilterIterator> dialectAttrs,
iterator_range<DialectFilterIterator> dialectTypes,
raw_ostream &os) {
+ /// Build the list of dependent dialects
+ std::string dependentDialectRegistrations;
+ {
+ llvm::raw_string_ostream dialects_os(dependentDialectRegistrations);
+ for (StringRef dependentDialect : dialect.getDependentDialects())
+ dialects_os << llvm::formatv(dialectRegistrationTemplate,
+ dependentDialect);
+ }
// Emit the start of the decl.
std::string cppName = dialect.getCppClassName();
- os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName());
+ os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(),
+ dependentDialectRegistrations);
// Check for any attributes/types registered to this dialect. If there are,
// add the hooks for parsing/printing.
diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp
index c2dcdb8..b6f5f3f 100644
--- a/mlir/tools/mlir-tblgen/PassGen.cpp
+++ b/mlir/tools/mlir-tblgen/PassGen.cpp
@@ -36,6 +36,7 @@ static llvm::cl::opt<std::string>
/// {0}: The def name of the pass record.
/// {1}: The base class for the pass.
/// {2): The command line argument for the pass.
+/// {3}: The dependent dialects registration.
const char *const passDeclBegin = R"(
//===----------------------------------------------------------------------===//
// {0}
@@ -63,9 +64,20 @@ public:
return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
}
+ /// Return the dialect that must be loaded in the context before this pass.
+ void getDependentDialects(::mlir::DialectRegistry &registry) const override {
+ {3}
+ }
+
protected:
)";
+/// Registration for a single dependent dialect, to be inserted for each
+/// dependent dialect in the `getDependentDialects` above.
+const char *const dialectRegistrationTemplate = R"(
+ registry.insert<{0}>();
+)";
+
/// Emit the declarations for each of the pass options.
static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
for (const PassOption &opt : pass.getOptions()) {
@@ -94,8 +106,15 @@ static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
static void emitPassDecl(const Pass &pass, raw_ostream &os) {
StringRef defName = pass.getDef()->getName();
+ std::string dependentDialectRegistrations;
+ {
+ llvm::raw_string_ostream dialects_os(dependentDialectRegistrations);
+ for (StringRef dependentDialect : pass.getDependentDialects())
+ dialects_os << llvm::formatv(dialectRegistrationTemplate,
+ dependentDialect);
+ }
os << llvm::formatv(passDeclBegin, defName, pass.getBaseClass(),
- pass.getArgument());
+ pass.getArgument(), dependentDialectRegistrations);
emitPassOptionDecls(pass, os);
emitPassStatisticDecls(pass, os);
os << "};\n";
diff --git a/mlir/tools/mlir-translate/mlir-translate.cpp b/mlir/tools/mlir-translate/mlir-translate.cpp
index 914bd34..0d67286 100644
--- a/mlir/tools/mlir-translate/mlir-translate.cpp
+++ b/mlir/tools/mlir-translate/mlir-translate.cpp
@@ -88,7 +88,8 @@ int main(int argc, char **argv) {
// Processes the memory buffer with a new MLIRContext.
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
raw_ostream &os) {
- MLIRContext context;
+ MLIRContext context(false);
+ registerAllDialects(&context);
context.allowUnregisteredDialects();
context.printOpOnDiagnostic(!verifyDiagnostics);
llvm::SourceMgr sourceMgr;
diff --git a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp b/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp
index 97c94a5..bae95e1 100644
--- a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp
+++ b/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp
@@ -17,9 +17,6 @@
using namespace mlir;
using namespace mlir::quant;
-// Load the quant dialect
-static DialectRegistration<QuantizationDialect> QuantOpsRegistration;
-
namespace {
// Test UniformQuantizedValueConverter converts all APFloat to a magic number 5.
@@ -78,7 +75,8 @@ UniformQuantizedType getTestQuantizedType(Type storageType, MLIRContext *ctx) {
}
TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
- MLIRContext ctx;
+ MLIRContext ctx(/*loadAllDialects=*/false);
+ ctx.getOrLoadDialect<QuantizationDialect>();
IntegerType convertedType = IntegerType::get(8, &ctx);
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
TestUniformQuantizedValueConverter converter(quantizedType);
@@ -95,7 +93,8 @@ TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
}
TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
- MLIRContext ctx;
+ MLIRContext ctx(/*loadAllDialects=*/false);
+ ctx.getOrLoadDialect<QuantizationDialect>();
IntegerType convertedType = IntegerType::get(8, &ctx);
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
TestUniformQuantizedValueConverter converter(quantizedType);
@@ -119,7 +118,8 @@ TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
}
TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
- MLIRContext ctx;
+ MLIRContext ctx(/*loadAllDialects=*/false);
+ ctx.getOrLoadDialect<QuantizationDialect>();
IntegerType convertedType = IntegerType::get(8, &ctx);
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
TestUniformQuantizedValueConverter converter(quantizedType);
@@ -143,7 +143,8 @@ TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
}
TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) {
- MLIRContext ctx;
+ MLIRContext ctx(/*loadAllDialects=*/false);
+ ctx.getOrLoadDialect<QuantizationDialect>();
IntegerType convertedType = IntegerType::get(8, &ctx);
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
TestUniformQuantizedValueConverter converter(quantizedType);
diff --git a/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp
index fe5632d..4aa2ffe 100644
--- a/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp
@@ -38,7 +38,8 @@ using ::testing::StrEq;
/// diagnostic checking utilities.
class DeserializationTest : public ::testing::Test {
protected:
- DeserializationTest() {
+ DeserializationTest() : context(/*loadAllDialects=*/false) {
+ context.getOrLoadDialect<mlir::spirv::SPIRVDialect>();
// Register a diagnostic handler to capture the diagnostic so that we can
// check it later.
context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
index 3d57e55..cb89cd6 100644
--- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
@@ -36,7 +36,10 @@ using namespace mlir;
class SerializationTest : public ::testing::Test {
protected:
- SerializationTest() { createModuleOp(); }
+ SerializationTest() : context(/*loadAllDialects=*/false) {
+ context.getOrLoadDialect<mlir::spirv::SPIRVDialect>();
+ createModuleOp();
+ }
void createModuleOp() {
OpBuilder builder(&context);
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index df449a0..78f7dd5 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -32,7 +32,7 @@ static void testSplat(Type eltType, const EltTy &splatElt) {
namespace {
TEST(DenseSplatTest, BoolSplat) {
- MLIRContext context;
+ MLIRContext context(false);
IntegerType boolTy = IntegerType::get(1, &context);
RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
@@ -57,7 +57,7 @@ TEST(DenseSplatTest, BoolSplat) {
TEST(DenseSplatTest, LargeBoolSplat) {
constexpr int64_t boolCount = 56;
- MLIRContext context;
+ MLIRContext context(false);
IntegerType boolTy = IntegerType::get(1, &context);
RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy);
@@ -80,7 +80,7 @@ TEST(DenseSplatTest, LargeBoolSplat) {
}
TEST(DenseSplatTest, BoolNonSplat) {
- MLIRContext context;
+ MLIRContext context(false);
IntegerType boolTy = IntegerType::get(1, &context);
RankedTensorType shape = RankedTensorType::get({6}, boolTy);
@@ -92,7 +92,7 @@ TEST(DenseSplatTest, BoolNonSplat) {
TEST(DenseSplatTest, OddIntSplat) {
// Test detecting a splat with an odd(non 8-bit) integer bitwidth.
- MLIRContext context;
+ MLIRContext context(false);
constexpr size_t intWidth = 19;
IntegerType intTy = IntegerType::get(intWidth, &context);
APInt value(intWidth, 10);
@@ -101,7 +101,7 @@ TEST(DenseSplatTest, OddIntSplat) {
}
TEST(DenseSplatTest, Int32Splat) {
- MLIRContext context;
+ MLIRContext context(false);
IntegerType intTy = IntegerType::get(32, &context);
int value = 64;
@@ -109,7 +109,7 @@ TEST(DenseSplatTest, Int32Splat) {
}
TEST(DenseSplatTest, IntAttrSplat) {
- MLIRContext context;
+ MLIRContext context(false);
IntegerType intTy = IntegerType::get(85, &context);
Attribute value = IntegerAttr::get(intTy, 109);
@@ -117,7 +117,7 @@ TEST(DenseSplatTest, IntAttrSplat) {
}
TEST(DenseSplatTest, F32Splat) {
- MLIRContext context;
+ MLIRContext context(false);
FloatType floatTy = FloatType::getF32(&context);
float value = 10.0;
@@ -125,7 +125,7 @@ TEST(DenseSplatTest, F32Splat) {
}
TEST(DenseSplatTest, F64Splat) {
- MLIRContext context;
+ MLIRContext context(false);
FloatType floatTy = FloatType::getF64(&context);
double value = 10.0;
@@ -133,7 +133,7 @@ TEST(DenseSplatTest, F64Splat) {
}
TEST(DenseSplatTest, FloatAttrSplat) {
- MLIRContext context;
+ MLIRContext context(false);
FloatType floatTy = FloatType::getF32(&context);
Attribute value = FloatAttr::get(floatTy, 10.0);
@@ -141,7 +141,7 @@ TEST(DenseSplatTest, FloatAttrSplat) {
}
TEST(DenseSplatTest, BF16Splat) {
- MLIRContext context;
+ MLIRContext context(false);
FloatType floatTy = FloatType::getBF16(&context);
Attribute value = FloatAttr::get(floatTy, 10.0);
@@ -149,7 +149,7 @@ TEST(DenseSplatTest, BF16Splat) {
}
TEST(DenseSplatTest, StringSplat) {
- MLIRContext context;
+ MLIRContext context(false);
Type stringType =
OpaqueType::get(Identifier::get("test", &context), "string", &context);
StringRef value = "test-string";
@@ -157,7 +157,7 @@ TEST(DenseSplatTest, StringSplat) {
}
TEST(DenseSplatTest, StringAttrSplat) {
- MLIRContext context;
+ MLIRContext context(false);
Type stringType =
OpaqueType::get(Identifier::get("test", &context), "string", &context);
Attribute stringAttr = StringAttr::get("test-string", stringType);
@@ -165,28 +165,28 @@ TEST(DenseSplatTest, StringAttrSplat) {
}
TEST(DenseComplexTest, ComplexFloatSplat) {
- MLIRContext context;
+ MLIRContext context(false);
ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
std::complex<float> value(10.0, 15.0);
testSplat(complexType, value);
}
TEST(DenseComplexTest, ComplexIntSplat) {
- MLIRContext context;
+ MLIRContext context(false);
ComplexType complexType = ComplexType::get(IntegerType::get(64, &context));
std::complex<int64_t> value(10, 15);
testSplat(complexType, value);
}
TEST(DenseComplexTest, ComplexAPFloatSplat) {
- MLIRContext context;
+ MLIRContext context(false);
ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f));
testSplat(complexType, value);
}
TEST(DenseComplexTest, ComplexAPIntSplat) {
- MLIRContext context;
+ MLIRContext context(false);
ComplexType complexType = ComplexType::get(IntegerType::get(64, &context));
std::complex<APInt> value(APInt(64, 10), APInt(64, 15));
testSplat(complexType, value);
diff --git a/mlir/unittests/IR/DialectTest.cpp b/mlir/unittests/IR/DialectTest.cpp
index bc389ce..c43fb77 100644
--- a/mlir/unittests/IR/DialectTest.cpp
+++ b/mlir/unittests/IR/DialectTest.cpp
@@ -26,12 +26,12 @@ struct AnotherTestDialect : public Dialect {
};
TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) {
- MLIRContext context;
+ MLIRContext context(false);
// Registering a dialect with the same namespace twice should result in a
// failure.
- context.getOrCreateDialect<TestDialect>();
- ASSERT_DEATH(context.getOrCreateDialect<AnotherTestDialect>(), "");
+ context.getOrLoadDialect<TestDialect>();
+ ASSERT_DEATH(context.getOrLoadDialect<AnotherTestDialect>(), "");
}
} // end namespace
diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp
index 95ddccc..9669330 100644
--- a/mlir/unittests/IR/OperationSupportTest.cpp
+++ b/mlir/unittests/IR/OperationSupportTest.cpp
@@ -25,7 +25,7 @@ static Operation *createOp(MLIRContext *context,
namespace {
TEST(OperandStorageTest, NonResizable) {
- MLIRContext context;
+ MLIRContext context(false);
Builder builder(&context);
Operation *useOp =
@@ -49,7 +49,7 @@ TEST(OperandStorageTest, NonResizable) {
}
TEST(OperandStorageTest, Resizable) {
- MLIRContext context;
+ MLIRContext context(false);
Builder builder(&context);
Operation *useOp =
@@ -77,7 +77,7 @@ TEST(OperandStorageTest, Resizable) {
}
TEST(OperandStorageTest, RangeReplace) {
- MLIRContext context;
+ MLIRContext context(false);
Builder builder(&context);
Operation *useOp =
@@ -113,7 +113,7 @@ TEST(OperandStorageTest, RangeReplace) {
}
TEST(OperandStorageTest, MutableRange) {
- MLIRContext context;
+ MLIRContext context(false);
Builder builder(&context);
Operation *useOp =
diff --git a/mlir/unittests/Pass/AnalysisManagerTest.cpp b/mlir/unittests/Pass/AnalysisManagerTest.cpp
index a99df39..958cf43 100644
--- a/mlir/unittests/Pass/AnalysisManagerTest.cpp
+++ b/mlir/unittests/Pass/AnalysisManagerTest.cpp
@@ -24,7 +24,7 @@ struct OtherAnalysis {
};
TEST(AnalysisManagerTest, FineGrainModuleAnalysisPreservation) {
- MLIRContext context;
+ MLIRContext context(false);
// Test fine grain invalidation of the module analysis manager.
OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
@@ -45,7 +45,7 @@ TEST(AnalysisManagerTest, FineGrainModuleAnalysisPreservation) {
}
TEST(AnalysisManagerTest, FineGrainFunctionAnalysisPreservation) {
- MLIRContext context;
+ MLIRContext context(false);
Builder builder(&context);
// Create a function and a module.
@@ -74,7 +74,7 @@ TEST(AnalysisManagerTest, FineGrainFunctionAnalysisPreservation) {
}
TEST(AnalysisManagerTest, FineGrainChildFunctionAnalysisPreservation) {
- MLIRContext context;
+ MLIRContext context(false);
Builder builder(&context);
// Create a function and a module.
@@ -117,7 +117,7 @@ struct CustomInvalidatingAnalysis {
};
TEST(AnalysisManagerTest, CustomInvalidation) {
- MLIRContext context;
+ MLIRContext context(false);
Builder builder(&context);
// Create a function and a module.
diff --git a/mlir/unittests/SDBM/SDBMTest.cpp b/mlir/unittests/SDBM/SDBMTest.cpp
index 61d6706..bbe87e3 100644
--- a/mlir/unittests/SDBM/SDBMTest.cpp
+++ b/mlir/unittests/SDBM/SDBMTest.cpp
@@ -17,18 +17,17 @@
using namespace mlir;
-/// Load the SDBM dialect.
-static DialectRegistration<SDBMDialect> SDBMRegistration;
static MLIRContext *ctx() {
- static thread_local MLIRContext context;
+ static thread_local MLIRContext context(false);
+ context.getOrLoadDialect<SDBMDialect>();
return &context;
}
static SDBMDialect *dialect() {
static thread_local SDBMDialect *d = nullptr;
if (!d) {
- d = ctx()->getRegisteredDialect<SDBMDialect>();
+ d = ctx()->getOrLoadDialect<SDBMDialect>();
}
return d;
}
diff --git a/mlir/unittests/TableGen/OpBuildGen.cpp b/mlir/unittests/TableGen/OpBuildGen.cpp
index 3e3256e..46a37da 100644
--- a/mlir/unittests/TableGen/OpBuildGen.cpp
+++ b/mlir/unittests/TableGen/OpBuildGen.cpp
@@ -25,11 +25,16 @@ namespace mlir {
// Test Fixture
//===----------------------------------------------------------------------===//
+static MLIRContext &getContext() {
+ static MLIRContext ctx(false);
+ ctx.getOrLoadDialect<TestDialect>();
+ return ctx;
+}
/// Test fixture for providing basic utilities for testing.
class OpBuildGenTest : public ::testing::Test {
protected:
OpBuildGenTest()
- : ctx{}, builder(&ctx), loc(builder.getUnknownLoc()),
+ : ctx(getContext()), builder(&ctx), loc(builder.getUnknownLoc()),
i32Ty(builder.getI32Type()), f32Ty(builder.getF32Type()),
cstI32(builder.create<TableGenConstant>(loc, i32Ty)),
cstF32(builder.create<TableGenConstant>(loc, f32Ty)),
@@ -86,7 +91,7 @@ protected:
}
protected:
- MLIRContext ctx;
+ MLIRContext &ctx;
OpBuilder builder;
Location loc;
Type i32Ty;
diff --git a/mlir/unittests/TableGen/StructsGenTest.cpp b/mlir/unittests/TableGen/StructsGenTest.cpp
index c58fedb..14b0abc 100644
--- a/mlir/unittests/TableGen/StructsGenTest.cpp
+++ b/mlir/unittests/TableGen/StructsGenTest.cpp
@@ -42,7 +42,7 @@ static test::TestStruct getTestStruct(mlir::MLIRContext *context) {
/// Validates that test::TestStruct::classof correctly identifies a valid
/// test::TestStruct.
TEST(StructsGenTest, ClassofTrue) {
- mlir::MLIRContext context;
+ mlir::MLIRContext context(false);
auto structAttr = getTestStruct(&context);
ASSERT_TRUE(test::TestStruct::classof(structAttr));
}