diff --git a/src/smt/Congruence_closure.ml b/src/smt/Congruence_closure.ml index 97d6c847..2875defa 100644 --- a/src/smt/Congruence_closure.ml +++ b/src/smt/Congruence_closure.ml @@ -124,8 +124,7 @@ let signature cc (t:term): node term_cell option = | App_cst (f, a) -> App_cst (f, IArray.map find a) |> CCOpt.return | Custom {view;tc} -> Custom {tc; view=tc.tc_t_subst find view} |> CCOpt.return - | True - | Builtin _ + | Bool _ | If _ | Case _ -> None (* no congruence for these *) @@ -365,15 +364,16 @@ and add_new_term cc (t:term) : node = in (* register sub-terms, add [t] to their parent list *) begin match t.term_cell with - | True -> () + | Bool _-> () | App_cst (_, a) -> IArray.iter add_sub_t a | If (a,b,c) -> add_sub_t a; add_sub_t b; add_sub_t c | Case (u, _) -> add_sub_t u - | Builtin b -> Term.builtin_to_seq b add_sub_t - | Custom {view;tc} -> tc.tc_t_sub view add_sub_t + | Custom {view;tc} -> + (* add relevant subterms to the CC *) + tc.tc_t_relevant view add_sub_t end; (* remove term when we backtrack *) if not (cc.acts.at_lvl_0 ()) then ( @@ -501,10 +501,9 @@ let rec decompose_explain cc (e:explanation): unit = let l = r1.tc.tc_t_explain (same_class_t cc) r1.view r2.view in List.iter (fun (t,u) -> ps_add_obligation_t cc t u) l | If _, _ - | Builtin _, _ | App_cst _, _ | Case _, _ - | True, _ + | Bool _, _ | Custom _, _ -> assert false end diff --git a/src/smt/Lit.ml b/src/smt/Lit.ml index db4caee2..f13095b5 100644 --- a/src/smt/Lit.ml +++ b/src/smt/Lit.ml @@ -30,8 +30,6 @@ let atom ?(sign=true) (t:term) : t = let sign = if not sign' then not sign else sign in make ~sign (Lit_atom t) -let eq tst a b = atom ~sign:true (Term.eq tst a b) -let neq tst a b = atom ~sign:false (Term.eq tst a b) let expanded t = make ~sign:true (Lit_expanded t) let cstor_test tst cstor t = atom ~sign:true (Term.cstor_test tst cstor t) diff --git a/src/smt/Lit.mli b/src/smt/Lit.mli index 0269f3e4..1584d750 100644 --- a/src/smt/Lit.mli +++ b/src/smt/Lit.mli @@ -12,8 +12,6 @@ val fresh_with : ID.t -> t val fresh : unit -> t val dummy : t val atom : ?sign:bool -> term -> t -val eq : Term.state -> term -> term -> t -val neq : Term.state -> term -> term -> t val cstor_test : Term.state -> data_cstor -> term -> t val expanded : term -> t val hash : t -> int diff --git a/src/smt/Process.ml b/src/smt/Process.ml index 6f3d0abb..256073fb 100644 --- a/src/smt/Process.ml +++ b/src/smt/Process.ml @@ -38,7 +38,7 @@ module Conv = struct let conv_term (tst:Term.state) (t:A.term): Term.t = (* polymorphic equality *) - let mk_eq t u = Term.eq tst t u in + let mk_eq t u = Term.eq tst t u in (* TODO: use theory of booleans *) let mk_app f l = Term.app_cst tst f (IArray.of_list l) in let mk_const = Term.const tst in (* diff --git a/src/smt/Solver_types.ml b/src/smt/Solver_types.ml index 16175d48..2ab553c6 100644 --- a/src/smt/Solver_types.ml +++ b/src/smt/Solver_types.ml @@ -16,24 +16,15 @@ and term = { (* term shallow structure *) and 'a term_cell = - | True + | Bool of bool | App_cst of cst * 'a IArray.t (* full, first-order application *) | If of 'a * 'a * 'a | Case of 'a * 'a ID.Map.t (* check head constructor *) - | Builtin of 'a builtin | Custom of { view: 'a term_view_custom; tc: term_view_tc; } -and 'a builtin = - | B_not of 'a - | B_eq of 'a * 'a - | B_and of 'a list - | B_or of 'a list - | B_imply of 'a list * 'a - | B_distinct of 'a list - (** Methods on the custom term view whose leaves are ['a]. Terms must be comparable, hashable, printable, and provide some additional theory handles. @@ -63,6 +54,7 @@ and term_view_tc = { tc_t_is_semantic : 'a. 'a term_view_custom -> bool; (* is this a semantic term? semantic terms must be solvable *) tc_t_solve: cc_node term_view_custom -> cc_node term_view_custom -> solve_result; (* solve an equation between classes *) tc_t_sub : 'a. 'a term_view_custom -> 'a Sequence.t; (* iter on immediate subterms *) + tc_t_abs : 'a. self:'a -> 'a term_view_custom -> 'a * bool; (* remove the sign? *) tc_t_relevant : 'a. 'a term_view_custom -> 'a Sequence.t; (* iter on relevant immediate subterms *) tc_t_subst : 'a 'b. ('a -> 'b) -> 'a term_view_custom -> 'b term_view_custom; (* substitute immediate subterms and canonize *) tc_t_explain : 'a. 'a CCEqual.t -> 'a term_view_custom -> 'a term_view_custom -> ('a * 'a) list; @@ -286,7 +278,8 @@ let pp_term_top ~ids out t = () and pp_rec out t = match t.term_cell with - | True -> Fmt.string out "true" + | Bool true -> Fmt.string out "true" + | Bool false -> Fmt.string out "false" | App_cst (c, a) when IArray.is_empty a -> pp_id out (id_of_cst c) | App_cst (f,l) -> @@ -302,17 +295,6 @@ let pp_term_top ~ids out t = in Fmt.fprintf out "(@[match %a@ (@[%a@])@])" pp t print_map (ID.Map.to_seq m) - | Builtin (B_not t) -> Fmt.fprintf out "(@[not@ %a@])" pp t - | Builtin (B_and l) -> - Fmt.fprintf out "(@[and@ %a])" (Util.pp_list pp) l - | Builtin (B_or l) -> - Fmt.fprintf out "(@[or@ %a@])" (Util.pp_list pp) l - | Builtin (B_imply (a,b)) -> - Fmt.fprintf out "(@[=>@ %a@ %a@])" (Util.pp_list pp) a pp b - | Builtin (B_eq (a,b)) -> - Fmt.fprintf out "(@[=@ %a@ %a@])" pp a pp b - | Builtin (B_distinct l) -> - Fmt.fprintf out "(@[distinct@ %a@])" (Util.pp_list pp) l | Custom {view; tc} -> tc.tc_t_pp pp out view and pp_id = if ids then ID.pp else ID.pp_name diff --git a/src/smt/Term.ml b/src/smt/Term.ml index eaa01804..470268be 100644 --- a/src/smt/Term.ml +++ b/src/smt/Term.ml @@ -3,6 +3,23 @@ open Solver_types type t = term +type 'a custom = 'a Solver_types.term_view_custom = .. + +type tc = Solver_types.term_view_tc = { + tc_t_pp : 'a. 'a Fmt.printer -> 'a custom Fmt.printer; + tc_t_equal : 'a. 'a CCEqual.t -> 'a custom CCEqual.t; + tc_t_hash : 'a. 'a Hash.t -> 'a custom Hash.t; + tc_t_ty : 'a. ('a -> ty) -> 'a custom -> ty; + tc_t_is_semantic : 'a. 'a custom -> bool; + tc_t_solve : cc_node custom -> cc_node custom -> solve_result; + tc_t_sub : 'a. 'a custom -> 'a Sequence.t; + tc_t_abs : 'a. self:'a -> 'a custom -> 'a * bool; + tc_t_relevant : 'a. 'a custom -> 'a Sequence.t; + tc_t_subst : 'a 'b. ('a -> 'b) -> 'a custom -> 'b custom; + tc_t_explain : 'a. 'a CCEqual.t -> 'a custom -> 'a custom -> ('a * 'a) list; +} + + let[@inline] id t = t.term_id let[@inline] ty t = t.term_ty let[@inline] cell t = t.term_cell @@ -41,7 +58,7 @@ let create ?(size=1024) () : state = n=2; tbl=Term_cell.Tbl.create size; true_ = lazy (make st Term_cell.true_); - false_ = lazy (make st (Term_cell.not_ (true_ st))); + false_ = lazy (make st Term_cell.false_); } in ignore (Lazy.force st.true_); ignore (Lazy.force st.false_); (* not true *) @@ -59,71 +76,21 @@ let case st u m = make st (Term_cell.case u m) let if_ st a b c = make st (Term_cell.if_ a b c) -let not_ st t = make st (Term_cell.not_ t) - -let and_l st = function - | [] -> true_ st - | [t] -> t - | l -> make st (Term_cell.and_ l) - -let or_l st = function - | [] -> false_ st - | [t] -> t - | l -> make st (Term_cell.or_ l) - -let and_ st a b = and_l st [a;b] -let or_ st a b = or_l st [a;b] -let imply st a b = match a, b.term_cell with - | [], _ -> b - | _::_, Builtin (B_imply (a',b')) -> - make st (Term_cell.imply (CCList.append a a') b') - | _ -> make st (Term_cell.imply a b) -let eq st a b = make st (Term_cell.eq a b) -let distinct st l = make st (Term_cell.distinct l) -let neq st a b = make st (Term_cell.neq a b) -let builtin st b = make st (Term_cell.builtin b) - (* "eager" and, evaluating [a] first *) let and_eager st a b = if_ st a b (false_ st) +let custom st ~tc view = make st (Term_cell.custom ~tc view) + let cstor_test st cstor t = make st (Term_cell.cstor_test cstor t) let cstor_proj st cstor i t = make st (Term_cell.cstor_proj cstor i t) (* might need to tranfer the negation from [t] to [sign] *) let abs t : t * bool = match t.term_cell with - | Builtin (B_not t) -> t, false + | Custom {view;tc} -> tc.tc_t_abs ~self:t view | _ -> t, true -let fold_map_builtin - (f:'a -> term -> 'a * term) (acc:'a) (b:t builtin): 'a * t builtin = - let fold_binary acc a b = - let acc, a = f acc a in - let acc, b = f acc b in - acc, a, b - in - match b with - | B_not t -> - let acc, t' = f acc t in - acc, B_not t' - | B_and l -> - let acc, l = CCList.fold_map f acc l in - acc, B_and l - | B_or l -> - let acc, l = CCList.fold_map f acc l in - acc, B_or l - | B_eq (a,b) -> - let acc, a, b = fold_binary acc a b in - acc, B_eq (a, b) - | B_distinct l -> - let acc, l = CCList.fold_map f acc l in - acc, B_distinct l - | B_imply (a,b) -> - let acc, a = CCList.fold_map f acc a in - let acc, b = f acc b in - acc, B_imply (a, b) - -let[@inline] is_true t = match t.term_cell with True -> true | _ -> false -let is_false t = match t.term_cell with Builtin (B_not u) -> is_true u | _ -> false +let[@inline] is_true t = match t.term_cell with Bool true -> true | _ -> false +let[@inline] is_false t = match t.term_cell with Bool false -> true | _ -> false let[@inline] is_const t = match t.term_cell with | App_cst (_, a) -> IArray.is_empty a @@ -137,16 +104,6 @@ let[@inline] is_semantic t = match t.term_cell with | Custom {view;tc} -> tc.tc_t_is_semantic view | _ -> false -let map_builtin f b = - let (), b = fold_map_builtin (fun () t -> (), f t) () b in - b - -let builtin_to_seq b yield = match b with - | B_not t -> yield t - | B_or l | B_and l | B_distinct l -> List.iter yield l - | B_imply (a,b) -> List.iter yield a; yield b - | B_eq (a,b) -> yield a; yield b - module As_key = struct type t = term let compare = compare @@ -161,13 +118,12 @@ let to_seq t yield = let rec aux t = yield t; match t.term_cell with - | True -> () + | Bool _ -> () | App_cst (_,a) -> IArray.iter aux a | If (a,b,c) -> aux a; aux b; aux c | Case (t, m) -> aux t; ID.Map.iter (fun _ rhs -> aux rhs) m - | Builtin b -> builtin_to_seq b aux | Custom {view;tc} -> tc.tc_t_sub view aux in aux t @@ -205,5 +161,5 @@ let pp = Solver_types.pp_term let dummy : t = { term_id= -1; term_ty=Ty.prop; - term_cell=True; + term_cell=Term_cell.true_; } diff --git a/src/smt/Term.mli b/src/smt/Term.mli index 747169b2..7e9628e4 100644 --- a/src/smt/Term.mli +++ b/src/smt/Term.mli @@ -3,6 +3,22 @@ open Solver_types type t = term +type 'a custom = 'a Solver_types.term_view_custom = .. + +type tc = Solver_types.term_view_tc = { + tc_t_pp : 'a. 'a Fmt.printer -> 'a custom Fmt.printer; + tc_t_equal : 'a. 'a CCEqual.t -> 'a custom CCEqual.t; + tc_t_hash : 'a. 'a Hash.t -> 'a custom Hash.t; + tc_t_ty : 'a. ('a -> ty) -> 'a custom -> ty; + tc_t_is_semantic : 'a. 'a custom -> bool; + tc_t_solve : cc_node custom -> cc_node custom -> solve_result; + tc_t_sub : 'a. 'a custom -> 'a Sequence.t; + tc_t_abs : 'a. self:'a -> 'a custom -> 'a * bool; + tc_t_relevant : 'a. 'a custom -> 'a Sequence.t; + tc_t_subst : 'a 'b. ('a -> 'b) -> 'a custom -> 'b custom; + tc_t_explain : 'a. 'a CCEqual.t -> 'a custom -> 'a custom -> ('a * 'a) list; +} + val id : t -> int val cell : t -> term term_cell val ty : t -> Ty.t @@ -14,33 +30,22 @@ type state val create : ?size:int -> unit -> state +val make : state -> t term_cell -> t val true_ : state -> t val false_ : state -> t val const : state -> cst -> t val app_cst : state -> cst -> t IArray.t -> t val if_: state -> t -> t -> t -> t val case : state -> t -> t ID.Map.t -> t -val builtin : state -> t builtin -> t -val and_ : state -> t -> t -> t -val or_ : state -> t -> t -> t -val not_ : state -> t -> t -val imply : state -> t list -> t -> t -val eq : state -> t -> t -> t -val neq : state -> t -> t -> t -val distinct : state -> t list -> t val and_eager : state -> t -> t -> t (* evaluate left argument first *) +val custom : state -> tc:tc -> t custom -> t val cstor_test : state -> data_cstor -> term -> t val cstor_proj : state -> data_cstor -> int -> term -> t -val and_l : state -> t list -> t -val or_l : state -> t list -> t - +(* TODO: remove *) val abs : t -> t * bool -val map_builtin : (t -> t) -> t builtin -> t builtin -val builtin_to_seq : t builtin -> t Sequence.t - val to_seq : t -> t Sequence.t val all_terms : state -> t Sequence.t diff --git a/src/smt/Term_cell.ml b/src/smt/Term_cell.ml index c04e68e3..94a2eb71 100644 --- a/src/smt/Term_cell.ml +++ b/src/smt/Term_cell.ml @@ -3,6 +3,33 @@ open Solver_types (* TODO: normalization of {!term_cell} for use in signatures? *) +type 'a cell = 'a Solver_types.term_cell = + | Bool of bool + | App_cst of cst * 'a IArray.t + | If of 'a * 'a * 'a + | Case of 'a * 'a ID.Map.t + | Custom of { + view : 'a term_view_custom; + tc : term_view_tc; + } + +type 'a custom = 'a Solver_types.term_view_custom = .. + +type tc = Solver_types.term_view_tc = { + tc_t_pp : 'a. 'a Fmt.printer -> 'a term_view_custom Fmt.printer; + tc_t_equal : 'a. 'a CCEqual.t -> 'a term_view_custom CCEqual.t; + tc_t_hash : 'a. 'a Hash.t -> 'a term_view_custom Hash.t; + tc_t_ty : 'a. ('a -> ty) -> 'a term_view_custom -> ty; + tc_t_is_semantic : 'a. 'a term_view_custom -> bool; + tc_t_solve : cc_node term_view_custom -> cc_node term_view_custom -> solve_result; + tc_t_sub : 'a. 'a term_view_custom -> 'a Sequence.t; + tc_t_abs : 'a. self:'a -> 'a custom -> 'a * bool; + tc_t_relevant : 'a. 'a term_view_custom -> 'a Sequence.t; + tc_t_subst : + 'a 'b. ('a -> 'b) -> 'a term_view_custom -> 'b term_view_custom; + tc_t_explain : 'a. 'a CCEqual.t -> 'a term_view_custom -> 'a term_view_custom -> ('a * 'a) list; +} + type t = term term_cell module type ARG = sig @@ -16,7 +43,7 @@ module Make_eq(A : ARG) = struct let sub_eq = A.equal let hash (t:A.t term_cell) : int = match t with - | True -> 1 + | Bool b -> Hash.bool b | App_cst (f,l) -> Hash.combine3 4 (Cst.hash f) (Hash.iarray sub_hash l) | If (a,b,c) -> Hash.combine4 7 (sub_hash a) (sub_hash b) (sub_hash c) @@ -25,17 +52,11 @@ module Make_eq(A : ARG) = struct Hash.seq (Hash.pair ID.hash sub_hash) (ID.Map.to_seq m) in Hash.combine3 8 (sub_hash u) hash_m - | Builtin (B_not a) -> Hash.combine2 20 (sub_hash a) - | Builtin (B_and l) -> Hash.combine2 21 (Hash.list sub_hash l) - | Builtin (B_or l) -> Hash.combine2 22 (Hash.list sub_hash l) - | Builtin (B_imply (l1,t2)) -> Hash.combine3 23 (Hash.list sub_hash l1) (sub_hash t2) - | Builtin (B_eq (t1,t2)) -> Hash.combine3 24 (sub_hash t1) (sub_hash t2) - | Builtin (B_distinct l) -> Hash.combine2 26 (Hash.list sub_hash l) | Custom {view;tc} -> tc.tc_t_hash sub_hash view (* equality that relies on physical equality of subterms *) let equal (a:A.t term_cell) b : bool = match a, b with - | True, True -> true + | Bool b1, Bool b2 -> CCBool.equal b1 b2 | App_cst (f1, a1), App_cst (f2, a2) -> Cst.equal f1 f2 && IArray.equal sub_eq a1 a2 | If (a1,b1,c1), If (a2,b2,c2) -> @@ -49,25 +70,12 @@ module Make_eq(A : ARG) = struct m1 && ID.Map.for_all (fun k2 _ -> ID.Map.mem k2 m1) m2 - | Builtin b1, Builtin b2 -> - begin match b1, b2 with - | B_not a1, B_not a2 -> sub_eq a1 a2 - | B_and l1, B_and l2 - | B_or l1, B_or l2 -> CCEqual.list sub_eq l1 l2 - | B_distinct l1, B_distinct l2 -> CCEqual.list sub_eq l1 l2 - | B_eq (a1,b1), B_eq (a2,b2) -> sub_eq a1 a2 && sub_eq b1 b2 - | B_imply (a1,b1), B_imply (a2,b2) -> CCEqual.list sub_eq a1 a2 && sub_eq b1 b2 - | B_not _, _ | B_and _, _ | B_eq _, _ - | B_or _, _ | B_imply _, _ | B_distinct _, _ - -> false - end | Custom r1, Custom r2 -> r1.tc.tc_t_equal sub_eq r1.view r2.view - | True, _ + | Bool _, _ | App_cst _, _ | If _, _ | Case _, _ - | Builtin _, _ | Custom _, _ -> false end[@@inline] @@ -78,7 +86,8 @@ include Make_eq(struct let hash (t:term): int = t.term_id end) -let true_ = True +let true_ = Bool true +let false_ = Bool false let app_cst f a = App_cst (f, a) let const c = App_cst (c, IArray.empty) @@ -95,29 +104,6 @@ let cstor_proj cstor i t = let p = IArray.get (Lazy.force cstor.cstor_proj) i in app_cst p (IArray.singleton t) -let builtin b = - let mk_ x = Builtin x in - (* normalize a bit *) - begin match b with - | B_imply ([], x) -> x.term_cell - | B_eq (a,b) when a.term_id = b.term_id -> true_ - | B_eq (a,b) when a.term_id > b.term_id -> mk_ @@ B_eq (b,a) - | _ -> mk_ b - end - -let not_ t = match t.term_cell with - | Builtin (B_not t') -> t'.term_cell - | _ -> builtin (B_not t) - -let and_ l = builtin (B_and l) -let or_ l = builtin (B_or l) -let imply a b = builtin (B_imply (a,b)) -let eq a b = builtin (B_eq (a,b)) -let distinct = function - | [] | [_] -> true_ - | l -> builtin (B_distinct l) -let neq a b = distinct [a;b] - let custom ~tc view = Custom {view;tc} (* type of an application *) @@ -130,7 +116,7 @@ let rec app_ty_ ty l : Ty.t = match Ty.view ty, l with assert false let ty (t:t): Ty.t = match t with - | True -> Ty.prop + | Bool _ -> Ty.prop | App_cst (f, a) -> let n_args, ret = Cst.ty f |> Ty.unfold_n in if n_args = IArray.length a @@ -143,7 +129,6 @@ let ty (t:t): Ty.t = match t with | Case (_,m) -> let _, rhs = ID.Map.choose m in rhs.term_ty - | Builtin _ -> Ty.prop | Custom {view;tc} -> tc.tc_t_ty (fun t -> t.term_ty) view module Tbl = CCHashtbl.Make(struct diff --git a/src/smt/Term_cell.mli b/src/smt/Term_cell.mli index 5dd09799..e1afa933 100644 --- a/src/smt/Term_cell.mli +++ b/src/smt/Term_cell.mli @@ -1,26 +1,47 @@ open Solver_types +type 'a cell = 'a Solver_types.term_cell = + | Bool of bool + | App_cst of cst * 'a IArray.t + | If of 'a * 'a * 'a + | Case of 'a * 'a ID.Map.t + | Custom of { + view : 'a term_view_custom; + tc : term_view_tc; + } + +type 'a custom = 'a Solver_types.term_view_custom = .. + +type tc = Solver_types.term_view_tc = { + tc_t_pp : 'a. 'a Fmt.printer -> 'a term_view_custom Fmt.printer; + tc_t_equal : 'a. 'a CCEqual.t -> 'a term_view_custom CCEqual.t; + tc_t_hash : 'a. 'a Hash.t -> 'a term_view_custom Hash.t; + tc_t_ty : 'a. ('a -> ty) -> 'a term_view_custom -> ty; + tc_t_is_semantic : 'a. 'a term_view_custom -> bool; + tc_t_solve : cc_node term_view_custom -> cc_node term_view_custom -> solve_result; + tc_t_sub : 'a. 'a term_view_custom -> 'a Sequence.t; + tc_t_abs : 'a. self:'a -> 'a custom -> 'a * bool; + tc_t_relevant : 'a. 'a term_view_custom -> 'a Sequence.t; + tc_t_subst : + 'a 'b. ('a -> 'b) -> 'a term_view_custom -> 'b term_view_custom; + tc_t_explain : 'a. 'a CCEqual.t -> 'a term_view_custom -> 'a term_view_custom -> ('a * 'a) list; +} + + type t = term term_cell val equal : t -> t -> bool val hash : t -> int val true_ : t +val false_ : t val const : cst -> t val app_cst : cst -> term IArray.t -> t val cstor_test : data_cstor -> term -> t val cstor_proj : data_cstor -> int -> term -> t val case : term -> term ID.Map.t -> t val if_ : term -> term -> term -> t -val builtin : term builtin -> t -val and_ : term list -> t -val or_ : term list -> t -val not_ : term -> t -val imply : term list -> term -> t -val eq : term -> term -> t -val neq : term -> term -> t -val distinct : term list -> t val custom : tc:term_view_tc -> term term_view_custom -> t val ty : t -> Ty.t diff --git a/src/smt/Theory.ml b/src/smt/Theory.ml index 0d5b122e..3ec2b03b 100644 --- a/src/smt/Theory.ml +++ b/src/smt/Theory.ml @@ -1,16 +1,18 @@ -(** Runtime state of a theory, with all the operations it provides *) -type state = { - on_merge: Equiv_class.t -> Equiv_class.t -> Explanation.t -> unit; +(** Runtime state of a theory, with all the operations it provides. + ['a] is the internal state *) +type state = State : { + mutable st: 'a; + on_merge: 'a -> Equiv_class.t -> Equiv_class.t -> Explanation.t -> unit; (** Called when two classes are merged *) - on_assert: Lit.t -> unit; + on_assert: 'a -> Lit.t -> unit; (** Called when a literal becomes true *) - final_check: Lit.t Sequence.t -> unit; + final_check: 'a -> Lit.t Sequence.t -> unit; (** Final check, must be complete (i.e. must raise a conflict if the set of literals is not satisfiable) *) -} +} -> state (** Unsatisfiable conjunction. Will be turned into a set of literals, whose negation becomes a @@ -59,8 +61,9 @@ type t = { let make ~name ~make () : t = {name;make} let make_st - ?(on_merge=fun _ _ _ -> ()) - ?(on_assert=fun _ -> ()) + ?(on_merge=fun _ _ _ _ -> ()) + ?(on_assert=fun _ _ -> ()) ~final_check + ~st () : state = - { on_merge; on_assert; final_check } + State { st; on_merge; on_assert; final_check } diff --git a/src/smt/Theory_combine.ml b/src/smt/Theory_combine.ml index 37c240a4..8435369f 100644 --- a/src/smt/Theory_combine.ml +++ b/src/smt/Theory_combine.ml @@ -56,12 +56,12 @@ let assume_lit (self:t) (lit:Lit.t) : unit = begin match Lit.view lit with | Lit_fresh _ -> () | Lit_expanded _ - | Lit_atom {term_cell=True; _} -> () - | Lit_atom t when Term.is_false t -> assert false + | Lit_atom {term_cell=Bool true; _} -> () + | Lit_atom {term_cell=Bool false; _} -> () | Lit_atom _ -> (* transmit to CC and theories *) Congruence_closure.assert_lit (cc self) lit; - theories self (fun th -> th.Theory.on_assert lit); + theories self (fun (Theory.State th) -> th.on_assert th.st lit); end (* push clauses from {!lemma_queue} into the slice *) @@ -138,7 +138,7 @@ let if_sat (self:t) (slice:_) : _ Sat_solver.res = in (* final check for each theory *) theories self - (fun th -> th.Theory.final_check forms); + (fun (Theory.State th) -> th.final_check th.st forms); cdcl_return_res self (** {2 Various helpers} *) @@ -163,7 +163,7 @@ let act_raise_conflict e = raise (Exn_conflict e) (* when CC decided to merge [r1] and [r2], notify theories *) let on_merge_from_cc (self:t) r1 r2 e : unit = theories self - (fun th -> th.Theory.on_merge r1 r2 e) + (fun (Theory.State th) -> th.on_merge th.st r1 r2 e) let mk_cc_actions (self:t) : Congruence_closure.actions = let Sat_solver.Actions r = self.cdcl_acts in diff --git a/src/smt/th_bool/Dagon_th_bool.ml b/src/smt/th_bool/Dagon_th_bool.ml new file mode 100644 index 00000000..75a7aed2 --- /dev/null +++ b/src/smt/th_bool/Dagon_th_bool.ml @@ -0,0 +1,244 @@ + +(** {1 Theory of Booleans} *) + +open Dagon_smt + +module Fmt = CCFormat + +type term = Term.t + +(* TODO (long term): relevancy propagation *) + + +(* TODO: migrate the boolean terms in there? *) +(* TODO: Tseitin on the fly when a composite boolean term is asserted. + --> maybe, cache the clause inside the literal *) + +(* TODO: in theory (or terms?) have a way to evaluate custom terms + (like formulas) in a given model, for checking models *) + +type 'a builtin = + | B_not of 'a + | B_eq of 'a * 'a + | B_and of 'a list + | B_or of 'a list + | B_imply of 'a list * 'a + | B_distinct of 'a list + +let fold_map_builtin + (f:'a -> 't -> 'b * 'u) (acc:'a) (b:'t builtin): 'b * 'u builtin = + let fold_binary acc a b = + let acc, a = f acc a in + let acc, b = f acc b in + acc, a, b + in + match b with + | B_not t -> + let acc, t' = f acc t in + acc, B_not t' + | B_and l -> + let acc, l = CCList.fold_map f acc l in + acc, B_and l + | B_or l -> + let acc, l = CCList.fold_map f acc l in + acc, B_or l + | B_eq (a,b) -> + let acc, a, b = fold_binary acc a b in + acc, B_eq (a, b) + | B_distinct l -> + let acc, l = CCList.fold_map f acc l in + acc, B_distinct l + | B_imply (a,b) -> + let acc, a = CCList.fold_map f acc a in + let acc, b = f acc b in + acc, B_imply (a, b) + +let map_builtin f b = + let (), b = fold_map_builtin (fun () t -> (), f t) () b in + b + +let builtin_to_seq b yield = match b with + | B_not t -> yield t + | B_or l | B_and l | B_distinct l -> List.iter yield l + | B_imply (a,b) -> List.iter yield a; yield b + | B_eq (a,b) -> yield a; yield b + +type 'a Term.custom += + | Builtin of { + view: 'a builtin; + (* TODO: bool value + explanation *) + (* TODO: caching of Tseiting *) + } + +module TC = struct + let hash sub_hash = function + | Builtin {view; _} -> + begin match view with + | B_not a -> Hash.combine2 20 (sub_hash a) + | B_and l -> Hash.combine2 21 (Hash.list sub_hash l) + | B_or l -> Hash.combine2 22 (Hash.list sub_hash l) + | B_imply (l1,t2) -> Hash.combine3 23 (Hash.list sub_hash l1) (sub_hash t2) + | B_eq (t1,t2) -> Hash.combine3 24 (sub_hash t1) (sub_hash t2) + | B_distinct l -> Hash.combine2 26 (Hash.list sub_hash l) + end + | _ -> assert false + + let eq sub_eq a b = match a, b with + | Builtin {view=b1; _}, Builtin {view=b2;_} -> + begin match b1, b2 with + | B_not a1, B_not a2 -> sub_eq a1 a2 + | B_and l1, B_and l2 + | B_or l1, B_or l2 -> CCEqual.list sub_eq l1 l2 + | B_distinct l1, B_distinct l2 -> CCEqual.list sub_eq l1 l2 + | B_eq (a1,b1), B_eq (a2,b2) -> sub_eq a1 a2 && sub_eq b1 b2 + | B_imply (a1,b1), B_imply (a2,b2) -> CCEqual.list sub_eq a1 a2 && sub_eq b1 b2 + | B_not _, _ | B_and _, _ | B_eq _, _ + | B_or _, _ | B_imply _, _ | B_distinct _, _ + -> false + end + | Builtin _, _ + | _, Builtin _ -> false + | _ -> assert false + + let pp sub_pp out = function + | Builtin {view=b;_} -> + begin match b with + | B_not t -> Fmt.fprintf out "(@[not@ %a@])" sub_pp t + | B_and l -> + Fmt.fprintf out "(@[and@ %a])" (Util.pp_list sub_pp) l + | B_or l -> + Fmt.fprintf out "(@[or@ %a@])" (Util.pp_list sub_pp) l + | B_imply (a,b) -> + Fmt.fprintf out "(@[=>@ %a@ %a@])" (Util.pp_list sub_pp) a sub_pp b + | B_eq (a,b) -> + Fmt.fprintf out "(@[=@ %a@ %a@])" sub_pp a sub_pp b + | B_distinct l -> + Fmt.fprintf out "(@[distinct@ %a@])" (Util.pp_list sub_pp) l + end + | _ -> assert false + + let get_ty _ = function + | Builtin _ -> Ty.prop + | _ -> assert false + + (* no Shostak for builtins, everything goes through clauses to + the SAT solver *) + let is_semantic = function + | Builtin {view=_;_} -> false + | _ -> assert false + + let solve _ _ = assert false (* never called *) + + let sub = function + | Builtin {view;_} -> builtin_to_seq view + | _ -> assert false + + let relevant = function + | Builtin _ -> Sequence.empty (* no congruence closure *) + | _ -> assert false + + let abs ~self = function + | Builtin {view=B_not b; _} -> b, false + | _ -> self, true + + let subst _ _ = assert false (* no congruence *) + + let explain _eq _ _ = assert false (* no congruence *) + + let tc : Term_cell.tc = { + Term_cell. + tc_t_pp = pp; + tc_t_equal = eq; + tc_t_hash = hash; + tc_t_ty = get_ty; + tc_t_is_semantic = is_semantic; + tc_t_solve = solve; + tc_t_sub = sub; + tc_t_abs = abs; + tc_t_relevant = relevant; + tc_t_subst = subst; + tc_t_explain = explain + } +end + +let tc = TC.tc + +module T_cell = struct + type t = Term_cell.t + + let builtin b = + let mk_ x = Term_cell.custom ~tc (Builtin {view=x}) in + (* normalize a bit *) + begin match b with + | B_imply ([], x) -> Term.cell x + | B_eq (a,b) when Term.equal a b -> Term_cell.true_ + | B_eq (a,b) when Term.id a > Term.id b -> mk_ @@ B_eq (b,a) + | _ -> mk_ b + end + + let not_ t = match Term.cell t with + | Term_cell.Custom {view=Builtin {view=B_not t';_};_} -> Term.cell t' + | _ -> builtin (B_not t) + + let and_ l = builtin (B_and l) + let or_ l = builtin (B_or l) + let imply a b = builtin (B_imply (a,b)) + let eq a b = builtin (B_eq (a,b)) + let distinct = function + | [] | [_] -> Term_cell.true_ + | l -> builtin (B_distinct l) + let neq a b = distinct [a;b] +end + +module T = struct + let make = Term.make + + let not_ st t = make st (T_cell.not_ t) + + let and_l st = function + | [] -> Term.true_ st + | [t] -> t + | l -> make st (T_cell.and_ l) + + let or_l st = function + | [] -> Term.false_ st + | [t] -> t + | l -> make st (T_cell.or_ l) + + let and_ st a b = and_l st [a;b] + let or_ st a b = or_l st [a;b] + let imply st a b = match a, Term.cell b with + | [], _ -> b + | _::_, Term_cell.Custom {view=Builtin {view=B_imply (a',b')}; _} -> + make st (T_cell.imply (CCList.append a a') b') + | _ -> make st (T_cell.imply a b) + let eq st a b = make st (T_cell.eq a b) + let distinct st l = make st (T_cell.distinct l) + let neq st a b = make st (T_cell.neq a b) + let builtin st b = make st (T_cell.builtin b) +end + +module Lit = struct + type t = Lit.t + let eq tst a b = Lit.atom ~sign:true (T.eq tst a b) + let neq tst a b = Lit.atom ~sign:false (T.eq tst a b) +end + +type t = { + tst: Term.state; + acts: Theory.actions; +} + +let on_assert (self:t) (lit:Lit.t) = + assert false (* TODO: see if Lit is a bool term, in which case Tseitin it *) + +let th = + let make tst acts = + let st = {tst;acts} in + Theory.make_st + ~on_assert + ~final_check:(fun _ _ -> ()) + ~st + () + in + Theory.make ~name:"boolean" ~make () diff --git a/src/smt/th_bool/Dagon_th_bool.mli b/src/smt/th_bool/Dagon_th_bool.mli new file mode 100644 index 00000000..7f40a4a8 --- /dev/null +++ b/src/smt/th_bool/Dagon_th_bool.mli @@ -0,0 +1,50 @@ + +(** {1 Theory of Booleans} *) + +open Dagon_smt + +type term = Term.t + +type 'a builtin = + | B_not of 'a + | B_eq of 'a * 'a + | B_and of 'a list + | B_or of 'a list + | B_imply of 'a list * 'a + | B_distinct of 'a list + +val map_builtin : ('a -> 'b) -> 'a builtin -> 'b builtin +val builtin_to_seq : 'a builtin -> 'a Sequence.t + +module T_cell : sig + type t = Term_cell.t + val builtin : term builtin -> t + val and_ : term list -> t + val or_ : term list -> t + val not_ : term -> t + val imply : term list -> term -> t + val eq : term -> term -> t + val neq : term -> term -> t + val distinct : term list -> t +end + +module T : sig + val builtin : Term.state -> term builtin -> term + val and_ : Term.state -> term -> term -> term + val or_ : Term.state -> term -> term -> term + val not_ : Term.state -> term -> term + val imply : Term.state -> term list -> term -> term + val eq : Term.state -> term -> term -> term + val neq : Term.state -> term -> term -> term + val distinct : Term.state -> term list -> term + val and_l : Term.state -> term list -> term + val or_l : Term.state -> term list -> term +end + +module Lit : sig + type t = Lit.t + val eq : Term.state -> term -> term -> t + val neq : Term.state -> term -> term -> t +end + +val th : Dagon_smt.Theory.t diff --git a/src/smt/th_bool/jbuild b/src/smt/th_bool/jbuild new file mode 100644 index 00000000..b5303589 --- /dev/null +++ b/src/smt/th_bool/jbuild @@ -0,0 +1,10 @@ +; vim:ft=lisp: +(library + ((name Dagon_th_bool) + (public_name dagon.th_bool) + (libraries (containers dagon.smt)) + (flags (:standard -w +a-4-44-48-58-60@8 + -color always -safe-string -short-paths -open Dagon_util)) + (ocamlopt_flags (:standard -O3 -color always + -unbox-closures -unbox-closures-factor 20)))) +