aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/python/dialects/smt.py
blob: 6f0cd8835b65b96cc231184df265adc832b0e6a8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# RUN: %PYTHON %s | FileCheck %s

from mlir.dialects import smt, arith
from mlir.ir import Context, Location, Module, InsertionPoint, F32Type


def run(f):
    """Run ``f`` as one named test case inside a fresh MLIR context.

    Prints a ``TEST:`` banner with the function's name. It then creates an
    empty module under an unknown location and calls ``f(module)`` with the
    insertion point set to the module body. Finally it prints the resulting
    module and asserts that it verifies.

    Meant to be used as a decorator. It returns None, so the decorated name
    ends up bound to None at module level.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    # The Context must be entered before Location.unknown() is built, because
    # the location is created in the current context.
    with Context():
        with Location.unknown():
            mod = Module.create()
            body_ip = InsertionPoint(mod.body)
            with body_ip:
                f(mod)
            print(mod)
            assert mod.operation.verify()


# CHECK-LABEL: TEST: test_smoke
@run
def test_smoke(_module):
    """Create one true and one false ``smt.constant`` in the module body."""
    # Only the ops themselves are needed; the returned values are unused, so
    # they are not bound to names (avoids unused-variable lint warnings).
    smt.constant(True)
    smt.constant(False)
    # CHECK: smt.constant true
    # CHECK: smt.constant false


# CHECK-LABEL: TEST: test_types
@run
def test_types(_module):
    """Print the SMT boolean type and a 5-bit bit-vector type."""
    # CHECK: !smt.bool
    print(smt.bool_t())
    # CHECK: !smt.bv<5>
    print(smt.bv_t(5))


# CHECK-LABEL: TEST: test_solver_op
@run
def test_solver_op(_module):
    """Build ``smt.solver`` ops through the ``@smt.solver`` decorator.

    Covers three shapes:
    - ``foo1``: no inputs and no results; its body holds two constants.
    - ``foo2``: no inputs and one ``f32`` result produced by an arith constant.
    - ``foo3``: one ``f32`` input (``two``) and one ``f32`` result; its body
      returns its block argument.
    """

    @smt.solver
    def foo1():
        # Only the ops are needed in the solver body; the Python handles are
        # unused, so they are not bound to names.
        smt.constant(True)
        smt.constant(False)

    # CHECK: smt.solver() : () -> () {
    # CHECK:   %true = smt.constant true
    # CHECK:   %false = smt.constant false
    # CHECK: }

    f32 = F32Type.get()

    @smt.solver(results=[f32])
    def foo2():
        return arith.ConstantOp(f32, 1.0)

    # CHECK: %{{.*}} = smt.solver() : () -> f32 {
    # CHECK:   %[[CST1:.*]] = arith.constant 1.000000e+00 : f32
    # CHECK:   smt.yield %[[CST1]] : f32
    # CHECK: }

    two = arith.ConstantOp(f32, 2.0)
    # CHECK: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
    print(two)

    @smt.solver(inputs=[two], results=[f32])
    def foo3(z: f32):
        return z

    # CHECK: %{{.*}} = smt.solver(%[[CST2]]) : (f32) -> f32 {
    # CHECK: ^bb0(%[[ARG0:.*]]: f32):
    # CHECK:   smt.yield %[[ARG0]] : f32
    # CHECK: }


# CHECK-LABEL: TEST: test_export_smtlib
@run
def test_export_smtlib(module):
    """Export a solver that asserts ``true`` as SMT-LIB text and print it."""

    @smt.solver
    def foo1():
        # The constant op is created first, then passed to the assert.
        smt.assert_(smt.constant(True))

    # CHECK: ; solver scope 0
    # CHECK: (assert true)
    # CHECK: (reset)
    print(smt.export_smtlib(module.operation))