diff --git a/cvc5_pythonic_api/cvc5_pythonic.py b/cvc5_pythonic_api/cvc5_pythonic.py index ffc4da5..010a5fe 100644 --- a/cvc5_pythonic_api/cvc5_pythonic.py +++ b/cvc5_pythonic_api/cvc5_pythonic.py @@ -64,7 +64,6 @@ * FiniteDomainSort * Fixedpoint API * SMT2 file support - * recursive functions * Not missing, but different * Options * as expected @@ -150,6 +149,8 @@ def __init__(self): self.vars = {} # An increasing identifier used to make fresh identifiers self.next_fresh_var = 0 + # Function definitions to be added to a solver once it is created + self.defined_functions = [] def __del__(self): self.tm = None @@ -937,6 +938,50 @@ def FreshFunction(*sig): return Function(name, *sig) +def RecFunction(name, *sig): + """Create a new SMT uninterpreted function with the given sorts.""" + return Function(name, sig) + + +def RecAddDefinition(func, args, body): + """Define a new SMT recursive function with the given function declaration. + Replaces constants in `args` with bound variables. + + >>> fact = Function('fact', IntSort(), IntSort()) + >>> n = Int('n') + >>> RecAddDefinition(fact, n, If(n == 0, 1, n * fact(n - 1))) + >>> solve(Not(fact(5) == 120)) + unsat + """ + if is_app(args): + args = [args] + ctx = func.ctx + consts = [a.ast for a in args] + vars_ = [ctx.tm.mkVar(a.sort().ast, str(a)) for a in args] + subbed_body = body.ast.substitute(consts, vars_) + ctx.defined_functions.append(((func.ast, vars_, subbed_body), True)) + + +def AddDefinition(name, args, body): + """Define a new SMT function with the given function declaration. + Replaces constants in `args` with bound variables. + + >>> x, y = Ints('x y') + >>> minus = AddDefinition(minus, [x, y], x - y) + >>> solve(Not(minus(10, 5) == 5)) + unsat + """ + if is_app(args): + args = [args] + ctx = body.ctx + consts = [a.ast for a in args] + vars_ = [ctx.tm.mkVar(a.sort().ast, str(a)) for a in args] + subbed_body = body.ast.substitute(consts, vars_) + ctx.defined_functions.append( + ((name, vars_, subbed_body.getSort(), subbed_body), False) + ) + + ######################################### # # Expressions @@ -6006,6 +6051,16 @@ def initFromLogic(self): self.solver.setLogic(self.logic) self.solver.setOption("produce-models", "true") + def add_func_definitions(self): + """Add function definitions present in the current context""" + # FIXME: This is a temporary fix and should be removed once the base + # API have the proper solution in place. + for func in self.ctx.defined_functions: + if func[1]: + self.solver.defineFunRec(*func[0]) + else: + self.solver.defineFun(*func[0]) + def __del__(self): if self.solver is not None: self.solver = None @@ -6190,6 +6245,7 @@ def check(self, *assumptions): unsat >>> s.resetAssertions() """ + self.add_func_definitions() assumptions = _get_args(assumptions) r = CheckSatResult(self.solver.checkSatAssuming(*[a.ast for a in assumptions])) self.last_result = r