mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-09 04:35:35 -05:00
wip: msat solver
This commit is contained in:
parent
6ef3da9d02
commit
44259ec5fc
1 changed files with 155 additions and 47 deletions
|
|
@ -5,20 +5,27 @@
|
|||
module Vec = Msat.Vec
|
||||
module Log = Msat.Log
|
||||
module IM = Util.Int_map
|
||||
module CC_view = Sidekick_core.CC_view
|
||||
|
||||
module type ARG = sig
|
||||
include Sidekick_core.TERM_LIT_PROOF
|
||||
val cc_view : Term.t -> (Fun.t, Term.t, Term.t Iter.t) Sidekick_core.CC_view.t
|
||||
module A : Sidekick_core.CORE_TYPES
|
||||
open A
|
||||
val cc_view : Term.t -> (Fun.t, Term.t, Term.t Iter.t) CC_view.t
|
||||
end
|
||||
|
||||
module type S = Sidekick_core.SOLVER
|
||||
|
||||
module Make(A : ARG)
|
||||
(* : S with type A.Term.t = A.Term.t *)
|
||||
module Make(Solver_arg : ARG)
|
||||
(* : S with module A = Solver_arg.A *)
|
||||
= struct
|
||||
module A = Solver_arg.A
|
||||
module T = A.Term
|
||||
module Ty = A.Ty
|
||||
module Lit = A.Lit
|
||||
type term = T.t
|
||||
type ty = Ty.t
|
||||
type lit = Lit.t
|
||||
type value = A.Value.t
|
||||
|
||||
(** Custom keys for theory data.
|
||||
This imitates the classic tricks for heterogeneous maps
|
||||
|
|
@ -64,28 +71,33 @@ module Make(A : ARG)
|
|||
to the congruence closure. *)
|
||||
module Key_set = struct
|
||||
type 'a key = 'a CC_key.t
|
||||
type k1 =
|
||||
| K1 : {
|
||||
|
||||
type ke =
|
||||
| KE : {
|
||||
k: 'a key;
|
||||
e: exn;
|
||||
} -> k1
|
||||
} -> ke
|
||||
|
||||
type t = k1 IM.t
|
||||
type t = ke IM.t
|
||||
|
||||
let empty = IM.empty
|
||||
let is_empty = IM.is_empty
|
||||
|
||||
let[@inline] mem k t = IM.mem (CC_key.id k) t
|
||||
|
||||
let find (type a) (k : a key) (self:t) : a option =
|
||||
(** Find the content for this key.
|
||||
@raise Not_found if not present *)
|
||||
let find (type a) (k : a key) (self:t) : a =
|
||||
let (module K) = k in
|
||||
match IM.find K.id self with
|
||||
| K1 {e=K.Store v;_} -> Some v
|
||||
| _ -> None
|
||||
| exception Not_found -> None
|
||||
| KE {e=K.Store v;_} -> v
|
||||
| _ -> raise_notrace Not_found
|
||||
|
||||
let[@inline] find_opt k self = try Some (find k self) with Not_found -> None
|
||||
|
||||
let add (type a) (k : a key) (v:a) (self:t) : t =
|
||||
let (module K) = k in
|
||||
IM.add K.id (K1 {k; e=K.Store v}) self
|
||||
IM.add K.id (KE {k; e=K.Store v}) self
|
||||
|
||||
let remove (type a) (k: a key) self : t =
|
||||
let (module K) = k in
|
||||
|
|
@ -98,15 +110,34 @@ module Make(A : ARG)
|
|||
| None, None -> None
|
||||
| Some v, None
|
||||
| None, Some v -> Some v
|
||||
| Some (K1 {k=(module K1) as key1; e=pair1; }), Some (K1{e=pair2;_}) ->
|
||||
| Some (KE {k=(module KE) as key1; e=pair1; }), Some (KE{e=pair2;_}) ->
|
||||
match pair1, pair2 with
|
||||
| K1.Store v1, K1.Store v2 ->
|
||||
let v12 = K1.merge v1 v2 in (* merge content *)
|
||||
Some (K1 {k=key1; e=K1.Store v12; })
|
||||
| KE.Store v1, KE.Store v2 ->
|
||||
let v12 = KE.merge v1 v2 in (* merge content *)
|
||||
Some (KE {k=key1; e=KE.Store v12; })
|
||||
| _ -> assert false)
|
||||
m1 m2
|
||||
|
||||
let pp_pair out (K1 {k=(module K);e=x; _}) =
|
||||
type iter_fun = {
|
||||
iter_fun: 'a. 'a key -> 'a -> 'a -> unit;
|
||||
} [@@unboxed]
|
||||
|
||||
let iter_inter (f: iter_fun) (m1:t) (m2:t) : unit =
|
||||
if is_empty m1 || is_empty m2 then ()
|
||||
else (
|
||||
IM.iter
|
||||
(fun i (KE {k=(module Key) as key;e=e1}) ->
|
||||
match IM.find i m2 with
|
||||
| KE {e=e2;_} ->
|
||||
begin match e1, e2 with
|
||||
| Key.Store x, Key.Store y -> f.iter_fun key x y
|
||||
| _ -> assert false
|
||||
end
|
||||
| exception Not_found -> ())
|
||||
m1
|
||||
)
|
||||
|
||||
let pp_pair out (KE {k=(module K);e=x; _}) =
|
||||
match x with
|
||||
| K.Store x -> K.pp out x
|
||||
| _ -> assert false
|
||||
|
|
@ -117,26 +148,28 @@ module Make(A : ARG)
|
|||
end
|
||||
|
||||
(* the full argument to the congruence closure *)
|
||||
module A = struct
|
||||
include A
|
||||
module CC_A = struct
|
||||
include Solver_arg
|
||||
|
||||
module Data = Key_set
|
||||
module Actions = struct
|
||||
type t = {
|
||||
raise_conflict : 'a. Lit.t list -> Proof.t -> 'a;
|
||||
propagate : Lit.t -> reason:Lit.t Iter.t -> Proof.t -> unit;
|
||||
raise_conflict : 'a. Lit.t list -> A.Proof.t -> 'a;
|
||||
propagate : Lit.t -> reason:Lit.t Iter.t -> A.Proof.t -> unit;
|
||||
}
|
||||
let[@inline] raise_conflict a lits p = a.raise_conflict lits p
|
||||
let[@inline] propagate a lit ~reason p = a.propagate lit ~reason p
|
||||
end
|
||||
end
|
||||
|
||||
module CC = Sidekick_cc.Make(A)
|
||||
module CC = Sidekick_cc.Make(CC_A)
|
||||
module Expl = CC.Expl
|
||||
module N = CC.N
|
||||
|
||||
(** Internal solver, given to theories and to Msat *)
|
||||
module Solver_internal = struct
|
||||
module Solver_internal
|
||||
: Sidekick_core.SOLVER_INTERNAL with module A = A
|
||||
= struct
|
||||
module A = A
|
||||
|
||||
type th_states =
|
||||
|
|
@ -163,8 +196,8 @@ module Make(A : ARG)
|
|||
mutable msat_acts: msat_acts option;
|
||||
mutable on_partial_check: (t -> lit Iter.t -> unit) list;
|
||||
mutable on_final_check: (t -> lit Iter.t -> unit) list;
|
||||
mutable on_cc_merge: on_cc_merge list IM.t;
|
||||
mutable on_cc_new_term : on_cc_new_term IM.t;
|
||||
mutable on_cc_merge: on_cc_merge IM.t;
|
||||
mutable on_cc_new_term : on_cc_new_term list;
|
||||
}
|
||||
|
||||
and on_cc_merge = On_cc_merge : {
|
||||
|
|
@ -193,6 +226,43 @@ module Make(A : ARG)
|
|||
let[@inline] cc (t:t) = Lazy.force t.cc
|
||||
let[@inline] tst t = t.tst
|
||||
|
||||
let on_cc_merge self ~k f =
|
||||
self.on_cc_merge <- IM.add (CC_key.id k) (On_cc_merge{k;f}) self.on_cc_merge
|
||||
|
||||
let on_cc_new_term self ~k f =
|
||||
self.on_cc_new_term <- On_cc_new_term {k;f} :: self.on_cc_new_term
|
||||
|
||||
let on_cc_merge_all self f =
|
||||
(* just delegate this to the CC *)
|
||||
CC.on_merge (cc self) (fun _cc n1 _th1 n2 _th2 expl -> f self n1 n2 expl)
|
||||
|
||||
let handle_on_cc_merge (self:t) _cc n1 th1 n2 th2 expl : unit =
|
||||
if Key_set.is_empty th1 || Key_set.is_empty th2 then ()
|
||||
else (
|
||||
(* iterate over the intersection of [th1] and [th2] *)
|
||||
IM.iter
|
||||
(fun _ (On_cc_merge {f;k}) ->
|
||||
match Key_set.find k th1, Key_set.find k th2 with
|
||||
| x1, x2 -> f self n1 x1 n2 x2 expl
|
||||
| exception Not_found -> ())
|
||||
self.on_cc_merge
|
||||
)
|
||||
|
||||
(* called by the CC when a term is added *)
|
||||
let handle_on_cc_new_term (self:t) _cc n1 t1 : _ option =
|
||||
match self.on_cc_new_term with
|
||||
| [] -> None
|
||||
| l ->
|
||||
let map =
|
||||
List.fold_left
|
||||
(fun map (On_cc_new_term{k;f}) ->
|
||||
match f self n1 t1 with
|
||||
| None -> map
|
||||
| Some u -> Key_set.add k u map)
|
||||
Key_set.empty l
|
||||
in
|
||||
if Key_set.is_empty map then None else Some map
|
||||
|
||||
let[@inline] raise_conflict self c : 'a =
|
||||
Stat.incr self.count_conflict;
|
||||
match self.msat_acts with
|
||||
|
|
@ -209,6 +279,16 @@ module Make(A : ARG)
|
|||
|
||||
let[@inline] propagate_l self p cs : unit = propagate self p (fun()->cs)
|
||||
|
||||
let[@inline] cc_add_term self t = CC.add_term (cc self) t
|
||||
let[@inline] cc_merge self n1 n2 e = CC.Theory.merge (cc self) n1 n2 e
|
||||
let cc_merge_t self t1 t2 e =
|
||||
let lazy cc = self.cc in
|
||||
CC.Theory.merge cc (CC.add_term cc t1) (CC.add_term cc t2) e
|
||||
|
||||
let cc_data self ~k n =
|
||||
let data = N.th_data (CC.find (cc self) n) in
|
||||
Key_set.find_opt k data
|
||||
|
||||
let add_axiom_ self ~keep lits : unit =
|
||||
Stat.incr self.count_axiom;
|
||||
match self.msat_acts with
|
||||
|
|
@ -297,6 +377,9 @@ module Make(A : ARG)
|
|||
CC.set_as_lit cc n (Lit.abs lit);
|
||||
()
|
||||
|
||||
let on_final_check self f = self.on_final_check <- f :: self.on_final_check
|
||||
let on_partial_check self f = self.on_partial_check <- f :: self.on_partial_check
|
||||
|
||||
(* propagation from the bool solver *)
|
||||
let[@inline] partial_check (self:t) (acts:_ Msat.acts) : unit =
|
||||
check_ ~final:false self acts
|
||||
|
|
@ -343,11 +426,12 @@ module Make(A : ARG)
|
|||
on_partial_check=[];
|
||||
on_final_check=[];
|
||||
on_cc_merge=IM.empty;
|
||||
on_cc_new_term=IM.empty;
|
||||
on_cc_new_term=[];
|
||||
} in
|
||||
ignore (Lazy.force @@ self.cc : CC.t);
|
||||
let lazy cc = self.cc in
|
||||
CC.on_merge cc (handle_on_cc_merge self);
|
||||
CC.on_new_term cc (handle_on_cc_new_term self);
|
||||
self
|
||||
|
||||
end
|
||||
|
||||
type conflict = lit list
|
||||
|
|
@ -360,12 +444,15 @@ module Make(A : ARG)
|
|||
|
||||
module Atom = Sat_solver.Atom
|
||||
module Proof = Sat_solver.Proof
|
||||
type proof = Proof.t
|
||||
|
||||
(* main solver state *)
|
||||
type t = {
|
||||
si: Solver_internal.t;
|
||||
solver: Sat_solver.t;
|
||||
stat: Stat.t;
|
||||
count_clause: int Stat.counter;
|
||||
count_solve: int Stat.counter;
|
||||
(* config: Config.t *)
|
||||
}
|
||||
type solver = t
|
||||
|
|
@ -380,6 +467,20 @@ module Make(A : ARG)
|
|||
|
||||
type theory = (module THEORY)
|
||||
|
||||
let mk_theory (type st)
|
||||
~name ~create_and_setup
|
||||
?(push_level=fun _ -> ())
|
||||
?(pop_levels=fun _ _ -> ())
|
||||
() : theory =
|
||||
let module Th = struct
|
||||
type t = st
|
||||
let name = name
|
||||
let create_and_setup = create_and_setup
|
||||
let push_level = push_level
|
||||
let pop_levels = pop_levels
|
||||
end in
|
||||
(module Th : THEORY)
|
||||
|
||||
(** {2 Main} *)
|
||||
|
||||
let add_theory (self:t) (th:theory) : unit =
|
||||
|
|
@ -409,6 +510,8 @@ module Make(A : ARG)
|
|||
si;
|
||||
solver=Sat_solver.create ?store_proof ?size si;
|
||||
stat;
|
||||
count_clause=Stat.mk_int stat "solver-clauses";
|
||||
count_solve=Stat.mk_int stat "solver-solve";
|
||||
} in
|
||||
add_theory_l self theories;
|
||||
(* assert [true] and [not false] *)
|
||||
|
|
@ -435,6 +538,15 @@ module Make(A : ARG)
|
|||
let lit = Lit.atom (tst self) ?sign t in
|
||||
mk_atom_lit self lit
|
||||
|
||||
let add_clause self c = Sat_solver.add_clause_a self.solver (c:_ IArray.t:>_ array) A.Proof.default
|
||||
let add_clause_l self l = add_clause self (IArray.of_list l)
|
||||
|
||||
let add_clause_lits self l =
|
||||
add_clause self @@ IArray.map (mk_atom_lit self) l
|
||||
|
||||
let add_clause_lits_l self l =
|
||||
add_clause self @@ IArray.of_list_map (mk_atom_lit self) l
|
||||
|
||||
(** {2 Result} *)
|
||||
|
||||
module Unknown = struct
|
||||
|
|
@ -464,7 +576,7 @@ module Make(A : ARG)
|
|||
let pp_model = Model.pp
|
||||
*)
|
||||
|
||||
type res = (Model.t, Proof.t, lit IArray.t, Unknown.t) Sidekick_core.solver_res
|
||||
type res = (Model.t, Proof.t, unit ->lit IArray.t, Unknown.t) Sidekick_core.solver_res
|
||||
|
||||
(** {2 Main} *)
|
||||
|
||||
|
|
@ -482,15 +594,15 @@ module Make(A : ARG)
|
|||
|
||||
(* map boolean subterms to literals *)
|
||||
let add_bool_subterms_ (self:t) (t:T.t) : unit =
|
||||
Term.iter_dag t
|
||||
|> Iter.filter (fun t -> Ty.is_prop @@ Term.ty t)
|
||||
T.iter_dag t
|
||||
|> Iter.filter (fun t -> Ty.is_bool @@ T.ty t)
|
||||
|> Iter.filter
|
||||
(fun t -> match Term.view t with
|
||||
| Term.Not _ -> false (* will process the subterm just later *)
|
||||
(fun t -> match CC_A.cc_view t with
|
||||
| CC_view.Not _ -> false (* will process the subterm just later *)
|
||||
| _ -> true)
|
||||
|> Iter.iter
|
||||
(fun sub ->
|
||||
Log.debugf 5 (fun k->k "(@[solver.map-to-lit@ :subterm %a@])" Term.pp sub);
|
||||
Log.debugf 5 (fun k->k "(@[solver.map-to-lit@ :subterm %a@])" T.pp sub);
|
||||
ignore (mk_atom_t self sub : Sat_solver.atom))
|
||||
|
||||
let assume (self:t) (c:Lit.t IArray.t) : unit =
|
||||
|
|
@ -498,7 +610,7 @@ module Make(A : ARG)
|
|||
IArray.iter (fun lit -> add_bool_subterms_ self @@ Lit.term lit) c;
|
||||
let c = IArray.to_array_map (Sat_solver.make_atom sat) c in
|
||||
Stat.incr self.count_clause;
|
||||
Sat_solver.add_clause_a sat c Proof_default
|
||||
Sat_solver.add_clause_a sat c A.Proof.default
|
||||
|
||||
(* TODO: remove? use a special constant + micro theory instead?
|
||||
let[@inline] assume_distinct self l ~neq lit : unit =
|
||||
|
|
@ -512,8 +624,6 @@ module Make(A : ARG)
|
|||
*)
|
||||
()
|
||||
|
||||
(* TODO: main loop with iterative deepening of the unrolling limit
|
||||
(not the value depth limit) *)
|
||||
let solve ?(on_exit=[]) ?(check=true) ~assumptions (self:t) : res =
|
||||
let do_on_exit () =
|
||||
List.iter (fun f->f()) on_exit;
|
||||
|
|
@ -523,17 +633,15 @@ module Make(A : ARG)
|
|||
match r with
|
||||
| Sat_solver.Sat st ->
|
||||
Log.debugf 1 (fun k->k "SAT");
|
||||
let lits f = st.iter_trail f (fun _ -> ()) in
|
||||
let m = Theory_combine.mk_model (th_combine self) lits in
|
||||
let _lits f = st.iter_trail f (fun _ -> ()) in
|
||||
let m = Model.empty in
|
||||
(* TODO Theory_combine.mk_model (th_combine self) lits *)
|
||||
do_on_exit ();
|
||||
Sat m
|
||||
(*
|
||||
let env = Ast.env_empty in
|
||||
let m = Model.make ~env in
|
||||
…
|
||||
Unknown U_incomplete (* TODO *)
|
||||
*)
|
||||
| Sat_solver.Unsat us ->
|
||||
let uc () =
|
||||
clause_of_mclause @@ us.Msat.unsat_conflict ()
|
||||
in
|
||||
let pr =
|
||||
try
|
||||
let pr = us.get_proof () in
|
||||
|
|
@ -542,6 +650,6 @@ module Make(A : ARG)
|
|||
with Msat.Solver_intf.No_proof -> None
|
||||
in
|
||||
do_on_exit ();
|
||||
Unsat pr
|
||||
Unsat {proof=pr; unsat_core=uc}
|
||||
|
||||
end
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue