diff options
Diffstat (limited to 'mlir/test/python')
-rw-r--r-- | mlir/test/python/dialects/transform_tune_ext.py | 105 | ||||
-rw-r--r-- | mlir/test/python/ir/operation.py | 8 |
2 files changed, 99 insertions, 14 deletions
diff --git a/mlir/test/python/dialects/transform_tune_ext.py b/mlir/test/python/dialects/transform_tune_ext.py index dfb9359..eb2a083 100644 --- a/mlir/test/python/dialects/transform_tune_ext.py +++ b/mlir/test/python/dialects/transform_tune_ext.py @@ -1,21 +1,21 @@ # RUN: %PYTHON %s | FileCheck %s -from mlir.ir import * +from mlir import ir from mlir.dialects import transform from mlir.dialects.transform import tune, debug def run(f): - print("\nTEST:", f.__name__) - with Context(), Location.unknown(): - module = Module.create() - with InsertionPoint(module.body): + print("\n// TEST:", f.__name__) + with ir.Context(), ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): sequence = transform.SequenceOp( transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get(), ) - with InsertionPoint(sequence.body): + with ir.InsertionPoint(sequence.body): f(sequence.bodyTarget) transform.YieldOp() print(module) @@ -29,10 +29,10 @@ def testKnobOp(target): # CHECK: %[[HEADS_OR_TAILS:.*]] = transform.tune.knob<"coin"> options = [true, false] -> !transform.any_param heads_or_tails = tune.KnobOp( - result=any_param, name=StringAttr.get("coin"), options=[True, False] + result=any_param, name=ir.StringAttr.get("coin"), options=[True, False] ) # CHECK: transform.tune.knob<"animal"> options = ["cat", "dog", unit] -> !transform.any_param - tune.KnobOp(any_param, name="animal", options=["cat", "dog", UnitAttr.get()]) + tune.KnobOp(any_param, name="animal", options=["cat", "dog", ir.UnitAttr.get()]) # CHECK: transform.tune.knob<"tile_size"> options = [2, 4, 8, 16, 24, 32] -> !transform.any_param tune.KnobOp(any_param, "tile_size", [2, 4, 8, 16, 24, 32]) # CHECK: transform.tune.knob<"magic_value"> options = [2.000000e+00, 2.250000e+00, 2.500000e+00, 2.750000e+00, 3.000000e+00] -> !transform.any_param @@ -45,7 +45,10 @@ def testKnobOp(target): heads = tune.KnobOp(any_param, "coin", options=[True, False], selected=True) # CHECK: transform.tune.knob<"animal"> = "dog" from options = ["cat", "dog", unit] -> !transform.any_param tune.KnobOp( - any_param, name="animal", options=["cat", "dog", UnitAttr.get()], selected="dog" + any_param, + name="animal", + options=["cat", "dog", ir.UnitAttr.get()], + selected="dog", ) # CHECK: transform.tune.knob<"tile_size"> = 8 : i64 from options = [2, 4, 8, 16, 24, 32] -> !transform.any_param tune.KnobOp(any_param, "tile_size", [2, 4, 8, 16, 24, 32], selected=8) @@ -57,16 +60,90 @@ def testKnobOp(target): # CHECK: transform.tune.knob<"range_as_a_dict"> = 4 : i64 from options = {start = 2 : i64, step = 2 : i64, stop = 16 : i64} -> !transform.any_param # NB: Membership of `selected` in non-ArrayAttr `options` is _not_ verified. - i64 = IntegerType.get_signless(64) + i64 = ir.IntegerType.get_signless(64) tune.knob( any_param, "range_as_a_dict", - DictAttr.get( + ir.DictAttr.get( { - "start": IntegerAttr.get(i64, 2), - "stop": IntegerAttr.get(i64, 16), - "step": IntegerAttr.get(i64, 2), + "start": ir.IntegerAttr.get(i64, 2), + "stop": ir.IntegerAttr.get(i64, 16), + "step": ir.IntegerAttr.get(i64, 2), } ), selected=4, ) + + +# CHECK-LABEL: TEST: testAlternativesOp +@run +def testAlternativesOp(target): + any_param = transform.AnyParamType.get() + + # CHECK: %[[LEFT_OR_RIGHT_OUTCOME:.*]] = transform.tune.alternatives<"left_or_right"> -> !transform.any_param { + left_or_right = tune.AlternativesOp( + [transform.AnyParamType.get()], "left_or_right", 2 + ) + idx_for_left, idx_for_right = 0, 1 + with ir.InsertionPoint(left_or_right.alternatives[idx_for_left].blocks[0]): + # CHECK: %[[C0:.*]] = transform.param.constant 0 + i32_0 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0) + c0 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_0) + # CHECK: transform.yield %[[C0]] + transform.yield_(c0) + # CHECK-NEXT: }, { + with ir.InsertionPoint(left_or_right.alternatives[idx_for_right].blocks[0]): + # CHECK: %[[C1:.*]] = transform.param.constant 1 + i32_1 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 1) + c1 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_1) + # CHECK: transform.yield %[[C1]] + transform.yield_(c1) + # CHECK-NEXT: } + outcome_of_left_or_right_decision = left_or_right.results[0] + + # CHECK: transform.tune.alternatives<"fork_in_the_road"> selected_region = 0 -> !transform.any_param { + fork_in_the_road = tune.AlternativesOp( + [transform.AnyParamType.get()], "fork_in_the_road", 2, selected_region=0 + ) + with ir.InsertionPoint(fork_in_the_road.alternatives[idx_for_left].blocks[0]): + # CHECK: %[[C0:.*]] = transform.param.constant 0 + i32_0 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0) + c0 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_0) + # CHECK: transform.yield %[[C0]] + transform.yield_(c0) + # CHECK-NEXT: }, { + with ir.InsertionPoint(fork_in_the_road.alternatives[idx_for_right].blocks[0]): + # CHECK: %[[C1:.*]] = transform.param.constant 1 + i32_1 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 1) + c1 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_1) + # CHECK: transform.yield %[[C1]] + transform.yield_(c1) + # CHECK-NEXT: } + + # CHECK: transform.tune.alternatives<"left_or_right_as_before"> selected_region = %[[LEFT_OR_RIGHT_OUTCOME]] : !transform.any_param { + left_or_right_as_before = tune.AlternativesOp( + [], + "left_or_right_as_before", + 2, + selected_region=outcome_of_left_or_right_decision, + ) + with ir.InsertionPoint( + left_or_right_as_before.alternatives[idx_for_left].blocks[0] + ): + # CHECK: transform.param.constant 1337 + i32_1337 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 1337) + c1337 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_1337) + # CHECK: transform.debug.emit_param_as_remark + debug.emit_param_as_remark(c1337) + transform.yield_([]) + # CHECK-NEXT: }, { + with ir.InsertionPoint( + left_or_right_as_before.alternatives[idx_for_right].blocks[0] + ): + # CHECK: transform.param.constant 42 + i32_42 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42) + c42 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_42) + # CHECK: transform.debug.emit_param_as_remark + debug.emit_param_as_remark(c42) + transform.yield_([]) + # CHECK-NEXT: } diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index 4a3625c..cb4cfc8c 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -696,6 +696,7 @@ def testOperationPrint(): # CHECK: resource1: "0x08 module.operation.print(large_elements_limit=2) + # CHECK-LABEL: TEST: testKnownOpView @run def testKnownOpView(): @@ -969,6 +970,13 @@ def testOperationLoc(): assert op.location == loc assert op.operation.location == loc + another_loc = Location.name("another_loc") + op.location = another_loc + assert op.location == another_loc + assert op.operation.location == another_loc + # CHECK: loc("another_loc") + print(op.location) + # CHECK-LABEL: TEST: testModuleMerge @run |