diff options
Diffstat (limited to 'mlir/test/python/ir/dialects.py')
-rw-r--r-- | mlir/test/python/ir/dialects.py | 36 |
1 files changed, 36 insertions, 0 deletions
diff --git a/mlir/test/python/ir/dialects.py b/mlir/test/python/ir/dialects.py index d59c6a6..5a2ed68 100644 --- a/mlir/test/python/ir/dialects.py +++ b/mlir/test/python/ir/dialects.py @@ -121,3 +121,39 @@ def testAppendPrefixSearchPath(): sys.path.append(".") _cext.globals.append_dialect_search_prefix("custom_dialect") assert _cext.globals._check_dialect_module_loaded("custom") + + +# CHECK-LABEL: TEST: testDialectLoadOnCreate +@run +def testDialectLoadOnCreate(): + with Context(load_on_create_dialects=[]) as ctx: + ctx.emit_error_diagnostics = True + ctx.allow_unregistered_dialects = True + + def callback(d): + # CHECK: DIAGNOSTIC + # CHECK-SAME: op created with unregistered dialect + print(f"DIAGNOSTIC={d.message}") + return True + + handler = ctx.attach_diagnostic_handler(callback) + loc = Location.unknown(ctx) + try: + op = Operation.create("arith.addi", loc=loc) + ctx.allow_unregistered_dialects = False + op.verify() + except MLIRError as e: + pass + + with Context(load_on_create_dialects=["func"]) as ctx: + loc = Location.unknown(ctx) + fn = Operation.create("func.func", loc=loc) + + # TODO: This may require an update if a site wide policy is set. + # CHECK: Load on create: [] + print(f"Load on create: {get_load_on_create_dialects()}") + append_load_on_create_dialect("func") + # CHECK: Load on create: + # CHECK-SAME: func + print(f"Load on create: {get_load_on_create_dialects()}") + print(get_load_on_create_dialects()) |