diff --git a/src/smtlib/Typecheck.ml b/src/smtlib/Typecheck.ml index 51cc5722..9a67a726 100644 --- a/src/smtlib/Typecheck.ml +++ b/src/smtlib/Typecheck.ml @@ -29,10 +29,13 @@ module Ctx = struct and ty_kind = | K_atomic of Ty.def + type default_num = [`Real | `Int] + type t = { tst: T.store; names: (ID.t * kind) StrTbl.t; lets: T.t StrTbl.t; + mutable default_num: default_num; mutable loc: Loc.t option; (* current loc *) } @@ -40,9 +43,13 @@ module Ctx = struct tst; names=StrTbl.create 64; lets=StrTbl.create 16; + default_num=`Real; loc=None; } + let set_default_num_int self = self.default_num <- `Int + let set_default_num_real self = self.default_num <- `Real + let loc t = t.loc let set_loc ?loc t = t.loc <- loc @@ -137,6 +144,8 @@ let cast_to_real (ctx:Ctx.t) (t:T.t) : T.t = (* convert the whole structure to reals *) let l = LIA_view.to_lra conv l in T.lra ctx.tst l + | T.Ite (a,b,c) -> + T.ite ctx.tst a (conv b) (conv c) | _ -> errorf_ctx ctx "cannot cast term to real@ :term %a" T.pp t in @@ -227,9 +236,10 @@ let rec conv_term (ctx:Ctx.t) (t:PA.term) : T.t = | PA.True -> T.true_ tst | PA.False -> T.false_ tst | PA.Const s when is_num s -> - begin match string_as_z s with - | Some n -> T.lia tst (Const n) - | None -> + begin match string_as_z s, ctx.default_num with + | Some n, `Int -> T.lia tst (Const n) + | Some n, `Real -> T.lra tst (Const (Q.of_bigint n)) + | None, _ -> begin match string_as_q s with | Some n -> T.lra tst (Const n) | None -> errorf_ctx ctx "expected a number for %a" PA.pp_term t @@ -289,7 +299,11 @@ let rec conv_term (ctx:Ctx.t) (t:PA.term) : T.t = | PA.Eq (a,b) -> let a = conv_term ctx a in let b = conv_term ctx b in - Form.eq tst a b + if is_real a || is_real b then ( + Form.eq tst (cast_to_real ctx a) (cast_to_real ctx b) + ) else ( + Form.eq tst a b + ) | PA.Imply (a,b) -> let a = conv_term ctx a in let b = conv_term ctx b in @@ -427,6 +441,10 @@ let conv_fun_defs ctx decls bodies : A.definition list = defs *) +let is_lia s = + CCString.mem ~sub:"LIA" s || + CCString.mem ~sub:"LIRA" s + let rec conv_statement ctx (s:PA.statement): Stmt.t list = Log.debugf 4 (fun k->k "(@[<1>statement_of_ast@ %a@])" PA.pp_stmt s); Ctx.set_loc ctx ?loc:(PA.loc s); @@ -435,7 +453,13 @@ let rec conv_statement ctx (s:PA.statement): Stmt.t list = and conv_statement_aux ctx (stmt:PA.statement) : Stmt.t list = let tst = ctx.Ctx.tst in match PA.view stmt with - | PA.Stmt_set_logic s -> [Stmt.Stmt_set_logic s] + | PA.Stmt_set_logic logic -> + if is_lia logic then ( + Ctx.set_default_num_int ctx; + ) else ( + Ctx.set_default_num_real ctx; + ); + [Stmt.Stmt_set_logic logic] | PA.Stmt_set_option l -> [Stmt.Stmt_set_option l] | PA.Stmt_set_info (a,b) -> [Stmt.Stmt_set_info (a,b)] | PA.Stmt_exit -> [Stmt.Stmt_exit] diff --git a/src/smtlib/Typecheck.mli b/src/smtlib/Typecheck.mli index c526ce3a..ee2cc957 100644 --- a/src/smtlib/Typecheck.mli +++ b/src/smtlib/Typecheck.mli @@ -13,6 +13,8 @@ type 'a or_error = ('a, string) CCResult.t module Ctx : sig type t + val set_default_num_real : t -> unit + val set_default_num_int : t -> unit val create: T.store -> t end