aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/python/dialects/transform_smt_ext.py
blob: 3692fd92344a6aea497ed18c5bcccecc92f206cb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# RUN: %PYTHON %s | FileCheck %s

from mlir import ir
from mlir.dialects import transform, smt
from mlir.dialects.transform import smt as transform_smt


def run(f):
    """Decorator that builds, populates and prints a transform-script module.

    Prints a ``TEST:`` header naming ``f``. Then, inside a fresh MLIR context
    with an unknown location, it creates a module holding one
    ``transform.SequenceOp``. That op is built with failure mode
    ``Propagate``, an empty ``[]`` second argument and a root typed with
    ``transform.AnyOpType``.

    ``f`` is invoked with the sequence's ``bodyTarget`` while the insertion
    point is inside the sequence body. A ``transform.YieldOp`` is appended
    once ``f`` returns, and the whole module is printed.

    Because ``f`` is called here, decorating a function runs its test
    immediately. ``f`` itself is returned unchanged.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    with ir.Context(), ir.Location.unknown():
        test_module = ir.Module.create()
        with ir.InsertionPoint(test_module.body):
            propagate = transform.FailurePropagationMode.Propagate
            root_type = transform.AnyOpType.get()
            root_seq = transform.SequenceOp(propagate, [], root_type)
            # Ops created by `f` are placed in the sequence body ahead of the
            # terminator, which is only appended after `f` returns.
            with ir.InsertionPoint(root_seq.body):
                f(root_seq.bodyTarget)
                transform.YieldOp()
        print(test_module)
    return f


# CHECK-LABEL: TEST: testConstrainParamsOp
@run
def testConstrainParamsOp(target):
    """Constrain a transform param to the closed integer range [0, 43].

    A constant transform param (i32 value 42) is passed to
    ``transform_smt.ConstrainParamsOp``, which exposes it as an ``!smt.int``
    block argument. Inside that op's body the argument is bounded below by 0
    and above by 43 with two ``le`` comparisons. The bounds are conjoined with
    ``smt.AndOp``, and the conjunction is asserted with ``smt.AssertOp``.

    ``target`` is the sequence root handle passed in by ``run``; this test
    does not use it.
    """
    i32 = ir.IntegerType.get_signless(32)
    dummy_value = ir.IntegerAttr.get(i32, 42)
    # CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant
    symbolic_value = transform.ParamConstantOp(
        transform.AnyParamType.get(), dummy_value
    )
    # CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
    constrain_params = transform_smt.ConstrainParamsOp(
        [symbolic_value], [smt.IntType.get()]
    )
    # CHECK-NEXT: ^bb{{.*}}(%[[PARAM_AS_SMT_SYMB:.*]]: !smt.int):
    with ir.InsertionPoint(constrain_params.body):
        # The body's single block argument is the param, seen as an SMT int.
        param_as_smt = constrain_params.body.arguments[0]
        # CHECK: %[[C0:.*]] = smt.int.constant 0
        c0 = smt.IntConstantOp(ir.IntegerAttr.get(i32, 0))
        # CHECK: %[[C43:.*]] = smt.int.constant 43
        c43 = smt.IntConstantOp(ir.IntegerAttr.get(i32, 43))
        # CHECK: %[[LB:.*]] = smt.int.cmp le %[[C0]], %[[PARAM_AS_SMT_SYMB]]
        lb = smt.IntCmpOp(smt.IntPredicate.le, c0, param_as_smt)
        # CHECK: %[[UB:.*]] = smt.int.cmp le %[[PARAM_AS_SMT_SYMB]], %[[C43]]
        ub = smt.IntCmpOp(smt.IntPredicate.le, param_as_smt, c43)
        # CHECK: %[[BOUNDED:.*]] = smt.and %[[LB]], %[[UB]]
        bounded = smt.AndOp([lb, ub])
        # Use (rather than redefine) the BOUNDED capture so FileCheck verifies
        # that the assertion consumes the conjunction built above.
        # CHECK: smt.assert %[[BOUNDED]]
        smt.AssertOp(bounded)