diff --git a/src/msat-solver/Sidekick_msat_solver.ml b/src/msat-solver/Sidekick_msat_solver.ml index 59f26dee..1417ce49 100644 --- a/src/msat-solver/Sidekick_msat_solver.ml +++ b/src/msat-solver/Sidekick_msat_solver.ml @@ -5,20 +5,27 @@ module Vec = Msat.Vec module Log = Msat.Log module IM = Util.Int_map +module CC_view = Sidekick_core.CC_view module type ARG = sig - include Sidekick_core.TERM_LIT_PROOF - val cc_view : Term.t -> (Fun.t, Term.t, Term.t Iter.t) Sidekick_core.CC_view.t + module A : Sidekick_core.CORE_TYPES + open A + val cc_view : Term.t -> (Fun.t, Term.t, Term.t Iter.t) CC_view.t end module type S = Sidekick_core.SOLVER -module Make(A : ARG) -(* : S with type A.Term.t = A.Term.t *) +module Make(Solver_arg : ARG) +(* : S with module A = Solver_arg.A *) = struct + module A = Solver_arg.A module T = A.Term + module Ty = A.Ty module Lit = A.Lit + type term = T.t + type ty = Ty.t type lit = Lit.t + type value = A.Value.t (** Custom keys for theory data. This imitates the classic tricks for heterogeneous maps @@ -64,28 +71,33 @@ module Make(A : ARG) to the congruence closure. *) module Key_set = struct type 'a key = 'a CC_key.t - type k1 = - | K1 : { + + type ke = + | KE : { k: 'a key; e: exn; - } -> k1 + } -> ke - type t = k1 IM.t + type t = ke IM.t let empty = IM.empty + let is_empty = IM.is_empty let[@inline] mem k t = IM.mem (CC_key.id k) t - let find (type a) (k : a key) (self:t) : a option = + (** Find the content for this key. + @raise Not_found if not present *) + let find (type a) (k : a key) (self:t) : a = let (module K) = k in match IM.find K.id self with - | K1 {e=K.Store v;_} -> Some v - | _ -> None - | exception Not_found -> None + | KE {e=K.Store v;_} -> v + | _ -> raise_notrace Not_found + + let[@inline] find_opt k self = try Some (find k self) with Not_found -> None let add (type a) (k : a key) (v:a) (self:t) : t = let (module K) = k in - IM.add K.id (K1 {k; e=K.Store v}) self + IM.add K.id (KE {k; e=K.Store v}) self let remove (type a) (k: a key) self : t = let (module K) = k in @@ -98,15 +110,34 @@ module Make(A : ARG) | None, None -> None | Some v, None | None, Some v -> Some v - | Some (K1 {k=(module K1) as key1; e=pair1; }), Some (K1{e=pair2;_}) -> + | Some (KE {k=(module KE) as key1; e=pair1; }), Some (KE{e=pair2;_}) -> match pair1, pair2 with - | K1.Store v1, K1.Store v2 -> - let v12 = K1.merge v1 v2 in (* merge content *) - Some (K1 {k=key1; e=K1.Store v12; }) + | KE.Store v1, KE.Store v2 -> + let v12 = KE.merge v1 v2 in (* merge content *) + Some (KE {k=key1; e=KE.Store v12; }) | _ -> assert false) m1 m2 - let pp_pair out (K1 {k=(module K);e=x; _}) = + type iter_fun = { + iter_fun: 'a. 'a key -> 'a -> 'a -> unit; + } [@@unboxed] + + let iter_inter (f: iter_fun) (m1:t) (m2:t) : unit = + if is_empty m1 || is_empty m2 then () + else ( + IM.iter + (fun i (KE {k=(module Key) as key;e=e1}) -> + match IM.find i m2 with + | KE {e=e2;_} -> + begin match e1, e2 with + | Key.Store x, Key.Store y -> f.iter_fun key x y + | _ -> assert false + end + | exception Not_found -> ()) + m1 + ) + + let pp_pair out (KE {k=(module K);e=x; _}) = match x with | K.Store x -> K.pp out x | _ -> assert false @@ -117,26 +148,28 @@ module Make(A : ARG) end (* the full argument to the congruence closure *) - module A = struct - include A + module CC_A = struct + include Solver_arg module Data = Key_set module Actions = struct type t = { - raise_conflict : 'a. Lit.t list -> Proof.t -> 'a; - propagate : Lit.t -> reason:Lit.t Iter.t -> Proof.t -> unit; + raise_conflict : 'a. Lit.t list -> A.Proof.t -> 'a; + propagate : Lit.t -> reason:Lit.t Iter.t -> A.Proof.t -> unit; } let[@inline] raise_conflict a lits p = a.raise_conflict lits p let[@inline] propagate a lit ~reason p = a.propagate lit ~reason p end end - module CC = Sidekick_cc.Make(A) + module CC = Sidekick_cc.Make(CC_A) module Expl = CC.Expl module N = CC.N (** Internal solver, given to theories and to Msat *) - module Solver_internal = struct + module Solver_internal + : Sidekick_core.SOLVER_INTERNAL with module A = A + = struct module A = A type th_states = @@ -163,8 +196,8 @@ module Make(A : ARG) mutable msat_acts: msat_acts option; mutable on_partial_check: (t -> lit Iter.t -> unit) list; mutable on_final_check: (t -> lit Iter.t -> unit) list; - mutable on_cc_merge: on_cc_merge list IM.t; - mutable on_cc_new_term : on_cc_new_term IM.t; + mutable on_cc_merge: on_cc_merge IM.t; + mutable on_cc_new_term : on_cc_new_term list; } and on_cc_merge = On_cc_merge : { @@ -193,6 +226,43 @@ module Make(A : ARG) let[@inline] cc (t:t) = Lazy.force t.cc let[@inline] tst t = t.tst + let on_cc_merge self ~k f = + self.on_cc_merge <- IM.add (CC_key.id k) (On_cc_merge{k;f}) self.on_cc_merge + + let on_cc_new_term self ~k f = + self.on_cc_new_term <- On_cc_new_term {k;f} :: self.on_cc_new_term + + let on_cc_merge_all self f = + (* just delegate this to the CC *) + CC.on_merge (cc self) (fun _cc n1 _th1 n2 _th2 expl -> f self n1 n2 expl) + + let handle_on_cc_merge (self:t) _cc n1 th1 n2 th2 expl : unit = + if Key_set.is_empty th1 || Key_set.is_empty th2 then () + else ( + (* iterate over the intersection of [th1] and [th2] *) + IM.iter + (fun _ (On_cc_merge {f;k}) -> + match Key_set.find k th1, Key_set.find k th2 with + | x1, x2 -> f self n1 x1 n2 x2 expl + | exception Not_found -> ()) + self.on_cc_merge + ) + + (* called by the CC when a term is added *) + let handle_on_cc_new_term (self:t) _cc n1 t1 : _ option = + match self.on_cc_new_term with + | [] -> None + | l -> + let map = + List.fold_left + (fun map (On_cc_new_term{k;f}) -> + match f self n1 t1 with + | None -> map + | Some u -> Key_set.add k u map) + Key_set.empty l + in + if Key_set.is_empty map then None else Some map + let[@inline] raise_conflict self c : 'a = Stat.incr self.count_conflict; match self.msat_acts with @@ -209,6 +279,16 @@ module Make(A : ARG) let[@inline] propagate_l self p cs : unit = propagate self p (fun()->cs) + let[@inline] cc_add_term self t = CC.add_term (cc self) t + let[@inline] cc_merge self n1 n2 e = CC.Theory.merge (cc self) n1 n2 e + let cc_merge_t self t1 t2 e = + let lazy cc = self.cc in + CC.Theory.merge cc (CC.add_term cc t1) (CC.add_term cc t2) e + + let cc_data self ~k n = + let data = N.th_data (CC.find (cc self) n) in + Key_set.find_opt k data + let add_axiom_ self ~keep lits : unit = Stat.incr self.count_axiom; match self.msat_acts with @@ -297,6 +377,9 @@ module Make(A : ARG) CC.set_as_lit cc n (Lit.abs lit); () + let on_final_check self f = self.on_final_check <- f :: self.on_final_check + let on_partial_check self f = self.on_partial_check <- f :: self.on_partial_check + (* propagation from the bool solver *) let[@inline] partial_check (self:t) (acts:_ Msat.acts) : unit = check_ ~final:false self acts @@ -343,11 +426,12 @@ module Make(A : ARG) on_partial_check=[]; on_final_check=[]; on_cc_merge=IM.empty; - on_cc_new_term=IM.empty; + on_cc_new_term=[]; } in - ignore (Lazy.force @@ self.cc : CC.t); + let lazy cc = self.cc in + CC.on_merge cc (handle_on_cc_merge self); + CC.on_new_term cc (handle_on_cc_new_term self); self - end type conflict = lit list @@ -360,12 +444,15 @@ module Make(A : ARG) module Atom = Sat_solver.Atom module Proof = Sat_solver.Proof + type proof = Proof.t (* main solver state *) type t = { si: Solver_internal.t; solver: Sat_solver.t; stat: Stat.t; + count_clause: int Stat.counter; + count_solve: int Stat.counter; (* config: Config.t *) } type solver = t @@ -380,6 +467,20 @@ module Make(A : ARG) type theory = (module THEORY) + let mk_theory (type st) + ~name ~create_and_setup + ?(push_level=fun _ -> ()) + ?(pop_levels=fun _ _ -> ()) + () : theory = + let module Th = struct + type t = st + let name = name + let create_and_setup = create_and_setup + let push_level = push_level + let pop_levels = pop_levels + end in + (module Th : THEORY) + (** {2 Main} *) let add_theory (self:t) (th:theory) : unit = @@ -409,6 +510,8 @@ module Make(A : ARG) si; solver=Sat_solver.create ?store_proof ?size si; stat; + count_clause=Stat.mk_int stat "solver-clauses"; + count_solve=Stat.mk_int stat "solver-solve"; } in add_theory_l self theories; (* assert [true] and [not false] *) @@ -435,6 +538,15 @@ module Make(A : ARG) let lit = Lit.atom (tst self) ?sign t in mk_atom_lit self lit + let add_clause self c = Sat_solver.add_clause_a self.solver (c:_ IArray.t:>_ array) A.Proof.default + let add_clause_l self l = add_clause self (IArray.of_list l) + + let add_clause_lits self l = + add_clause self @@ IArray.map (mk_atom_lit self) l + + let add_clause_lits_l self l = + add_clause self @@ IArray.of_list_map (mk_atom_lit self) l + (** {2 Result} *) module Unknown = struct @@ -464,7 +576,7 @@ module Make(A : ARG) let pp_model = Model.pp *) - type res = (Model.t, Proof.t, lit IArray.t, Unknown.t) Sidekick_core.solver_res + type res = (Model.t, Proof.t, unit ->lit IArray.t, Unknown.t) Sidekick_core.solver_res (** {2 Main} *) @@ -482,15 +594,15 @@ module Make(A : ARG) (* map boolean subterms to literals *) let add_bool_subterms_ (self:t) (t:T.t) : unit = - Term.iter_dag t - |> Iter.filter (fun t -> Ty.is_prop @@ Term.ty t) + T.iter_dag t + |> Iter.filter (fun t -> Ty.is_bool @@ T.ty t) |> Iter.filter - (fun t -> match Term.view t with - | Term.Not _ -> false (* will process the subterm just later *) + (fun t -> match CC_A.cc_view t with + | CC_view.Not _ -> false (* will process the subterm just later *) | _ -> true) |> Iter.iter (fun sub -> - Log.debugf 5 (fun k->k "(@[solver.map-to-lit@ :subterm %a@])" Term.pp sub); + Log.debugf 5 (fun k->k "(@[solver.map-to-lit@ :subterm %a@])" T.pp sub); ignore (mk_atom_t self sub : Sat_solver.atom)) let assume (self:t) (c:Lit.t IArray.t) : unit = @@ -498,7 +610,7 @@ module Make(A : ARG) IArray.iter (fun lit -> add_bool_subterms_ self @@ Lit.term lit) c; let c = IArray.to_array_map (Sat_solver.make_atom sat) c in Stat.incr self.count_clause; - Sat_solver.add_clause_a sat c Proof_default + Sat_solver.add_clause_a sat c A.Proof.default (* TODO: remove? use a special constant + micro theory instead? let[@inline] assume_distinct self l ~neq lit : unit = @@ -512,8 +624,6 @@ module Make(A : ARG) *) () - (* TODO: main loop with iterative deepening of the unrolling limit - (not the value depth limit) *) let solve ?(on_exit=[]) ?(check=true) ~assumptions (self:t) : res = let do_on_exit () = List.iter (fun f->f()) on_exit; @@ -523,17 +633,15 @@ module Make(A : ARG) match r with | Sat_solver.Sat st -> Log.debugf 1 (fun k->k "SAT"); - let lits f = st.iter_trail f (fun _ -> ()) in - let m = Theory_combine.mk_model (th_combine self) lits in + let _lits f = st.iter_trail f (fun _ -> ()) in + let m = Model.empty in + (* TODO Theory_combine.mk_model (th_combine self) lits *) do_on_exit (); Sat m - (* - let env = Ast.env_empty in - let m = Model.make ~env in - … - Unknown U_incomplete (* TODO *) - *) | Sat_solver.Unsat us -> + let uc () = + clause_of_mclause @@ us.Msat.unsat_conflict () + in let pr = try let pr = us.get_proof () in @@ -542,6 +650,6 @@ module Make(A : ARG) with Msat.Solver_intf.No_proof -> None in do_on_exit (); - Unsat pr + Unsat {proof=pr; unsat_core=uc} end