mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-10 21:24:06 -05:00
wip: msat solver
This commit is contained in:
parent
6ef3da9d02
commit
44259ec5fc
1 changed files with 155 additions and 47 deletions
|
|
@ -5,20 +5,27 @@
|
||||||
module Vec = Msat.Vec
|
module Vec = Msat.Vec
|
||||||
module Log = Msat.Log
|
module Log = Msat.Log
|
||||||
module IM = Util.Int_map
|
module IM = Util.Int_map
|
||||||
|
module CC_view = Sidekick_core.CC_view
|
||||||
|
|
||||||
module type ARG = sig
|
module type ARG = sig
|
||||||
include Sidekick_core.TERM_LIT_PROOF
|
module A : Sidekick_core.CORE_TYPES
|
||||||
val cc_view : Term.t -> (Fun.t, Term.t, Term.t Iter.t) Sidekick_core.CC_view.t
|
open A
|
||||||
|
val cc_view : Term.t -> (Fun.t, Term.t, Term.t Iter.t) CC_view.t
|
||||||
end
|
end
|
||||||
|
|
||||||
module type S = Sidekick_core.SOLVER
|
module type S = Sidekick_core.SOLVER
|
||||||
|
|
||||||
module Make(A : ARG)
|
module Make(Solver_arg : ARG)
|
||||||
(* : S with type A.Term.t = A.Term.t *)
|
(* : S with module A = Solver_arg.A *)
|
||||||
= struct
|
= struct
|
||||||
|
module A = Solver_arg.A
|
||||||
module T = A.Term
|
module T = A.Term
|
||||||
|
module Ty = A.Ty
|
||||||
module Lit = A.Lit
|
module Lit = A.Lit
|
||||||
|
type term = T.t
|
||||||
|
type ty = Ty.t
|
||||||
type lit = Lit.t
|
type lit = Lit.t
|
||||||
|
type value = A.Value.t
|
||||||
|
|
||||||
(** Custom keys for theory data.
|
(** Custom keys for theory data.
|
||||||
This imitates the classic tricks for heterogeneous maps
|
This imitates the classic tricks for heterogeneous maps
|
||||||
|
|
@ -64,28 +71,33 @@ module Make(A : ARG)
|
||||||
to the congruence closure. *)
|
to the congruence closure. *)
|
||||||
module Key_set = struct
|
module Key_set = struct
|
||||||
type 'a key = 'a CC_key.t
|
type 'a key = 'a CC_key.t
|
||||||
type k1 =
|
|
||||||
| K1 : {
|
type ke =
|
||||||
|
| KE : {
|
||||||
k: 'a key;
|
k: 'a key;
|
||||||
e: exn;
|
e: exn;
|
||||||
} -> k1
|
} -> ke
|
||||||
|
|
||||||
type t = k1 IM.t
|
type t = ke IM.t
|
||||||
|
|
||||||
let empty = IM.empty
|
let empty = IM.empty
|
||||||
|
let is_empty = IM.is_empty
|
||||||
|
|
||||||
let[@inline] mem k t = IM.mem (CC_key.id k) t
|
let[@inline] mem k t = IM.mem (CC_key.id k) t
|
||||||
|
|
||||||
let find (type a) (k : a key) (self:t) : a option =
|
(** Find the content for this key.
|
||||||
|
@raise Not_found if not present *)
|
||||||
|
let find (type a) (k : a key) (self:t) : a =
|
||||||
let (module K) = k in
|
let (module K) = k in
|
||||||
match IM.find K.id self with
|
match IM.find K.id self with
|
||||||
| K1 {e=K.Store v;_} -> Some v
|
| KE {e=K.Store v;_} -> v
|
||||||
| _ -> None
|
| _ -> raise_notrace Not_found
|
||||||
| exception Not_found -> None
|
|
||||||
|
let[@inline] find_opt k self = try Some (find k self) with Not_found -> None
|
||||||
|
|
||||||
let add (type a) (k : a key) (v:a) (self:t) : t =
|
let add (type a) (k : a key) (v:a) (self:t) : t =
|
||||||
let (module K) = k in
|
let (module K) = k in
|
||||||
IM.add K.id (K1 {k; e=K.Store v}) self
|
IM.add K.id (KE {k; e=K.Store v}) self
|
||||||
|
|
||||||
let remove (type a) (k: a key) self : t =
|
let remove (type a) (k: a key) self : t =
|
||||||
let (module K) = k in
|
let (module K) = k in
|
||||||
|
|
@ -98,15 +110,34 @@ module Make(A : ARG)
|
||||||
| None, None -> None
|
| None, None -> None
|
||||||
| Some v, None
|
| Some v, None
|
||||||
| None, Some v -> Some v
|
| None, Some v -> Some v
|
||||||
| Some (K1 {k=(module K1) as key1; e=pair1; }), Some (K1{e=pair2;_}) ->
|
| Some (KE {k=(module KE) as key1; e=pair1; }), Some (KE{e=pair2;_}) ->
|
||||||
match pair1, pair2 with
|
match pair1, pair2 with
|
||||||
| K1.Store v1, K1.Store v2 ->
|
| KE.Store v1, KE.Store v2 ->
|
||||||
let v12 = K1.merge v1 v2 in (* merge content *)
|
let v12 = KE.merge v1 v2 in (* merge content *)
|
||||||
Some (K1 {k=key1; e=K1.Store v12; })
|
Some (KE {k=key1; e=KE.Store v12; })
|
||||||
| _ -> assert false)
|
| _ -> assert false)
|
||||||
m1 m2
|
m1 m2
|
||||||
|
|
||||||
let pp_pair out (K1 {k=(module K);e=x; _}) =
|
type iter_fun = {
|
||||||
|
iter_fun: 'a. 'a key -> 'a -> 'a -> unit;
|
||||||
|
} [@@unboxed]
|
||||||
|
|
||||||
|
let iter_inter (f: iter_fun) (m1:t) (m2:t) : unit =
|
||||||
|
if is_empty m1 || is_empty m2 then ()
|
||||||
|
else (
|
||||||
|
IM.iter
|
||||||
|
(fun i (KE {k=(module Key) as key;e=e1}) ->
|
||||||
|
match IM.find i m2 with
|
||||||
|
| KE {e=e2;_} ->
|
||||||
|
begin match e1, e2 with
|
||||||
|
| Key.Store x, Key.Store y -> f.iter_fun key x y
|
||||||
|
| _ -> assert false
|
||||||
|
end
|
||||||
|
| exception Not_found -> ())
|
||||||
|
m1
|
||||||
|
)
|
||||||
|
|
||||||
|
let pp_pair out (KE {k=(module K);e=x; _}) =
|
||||||
match x with
|
match x with
|
||||||
| K.Store x -> K.pp out x
|
| K.Store x -> K.pp out x
|
||||||
| _ -> assert false
|
| _ -> assert false
|
||||||
|
|
@ -117,26 +148,28 @@ module Make(A : ARG)
|
||||||
end
|
end
|
||||||
|
|
||||||
(* the full argument to the congruence closure *)
|
(* the full argument to the congruence closure *)
|
||||||
module A = struct
|
module CC_A = struct
|
||||||
include A
|
include Solver_arg
|
||||||
|
|
||||||
module Data = Key_set
|
module Data = Key_set
|
||||||
module Actions = struct
|
module Actions = struct
|
||||||
type t = {
|
type t = {
|
||||||
raise_conflict : 'a. Lit.t list -> Proof.t -> 'a;
|
raise_conflict : 'a. Lit.t list -> A.Proof.t -> 'a;
|
||||||
propagate : Lit.t -> reason:Lit.t Iter.t -> Proof.t -> unit;
|
propagate : Lit.t -> reason:Lit.t Iter.t -> A.Proof.t -> unit;
|
||||||
}
|
}
|
||||||
let[@inline] raise_conflict a lits p = a.raise_conflict lits p
|
let[@inline] raise_conflict a lits p = a.raise_conflict lits p
|
||||||
let[@inline] propagate a lit ~reason p = a.propagate lit ~reason p
|
let[@inline] propagate a lit ~reason p = a.propagate lit ~reason p
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
module CC = Sidekick_cc.Make(A)
|
module CC = Sidekick_cc.Make(CC_A)
|
||||||
module Expl = CC.Expl
|
module Expl = CC.Expl
|
||||||
module N = CC.N
|
module N = CC.N
|
||||||
|
|
||||||
(** Internal solver, given to theories and to Msat *)
|
(** Internal solver, given to theories and to Msat *)
|
||||||
module Solver_internal = struct
|
module Solver_internal
|
||||||
|
: Sidekick_core.SOLVER_INTERNAL with module A = A
|
||||||
|
= struct
|
||||||
module A = A
|
module A = A
|
||||||
|
|
||||||
type th_states =
|
type th_states =
|
||||||
|
|
@ -163,8 +196,8 @@ module Make(A : ARG)
|
||||||
mutable msat_acts: msat_acts option;
|
mutable msat_acts: msat_acts option;
|
||||||
mutable on_partial_check: (t -> lit Iter.t -> unit) list;
|
mutable on_partial_check: (t -> lit Iter.t -> unit) list;
|
||||||
mutable on_final_check: (t -> lit Iter.t -> unit) list;
|
mutable on_final_check: (t -> lit Iter.t -> unit) list;
|
||||||
mutable on_cc_merge: on_cc_merge list IM.t;
|
mutable on_cc_merge: on_cc_merge IM.t;
|
||||||
mutable on_cc_new_term : on_cc_new_term IM.t;
|
mutable on_cc_new_term : on_cc_new_term list;
|
||||||
}
|
}
|
||||||
|
|
||||||
and on_cc_merge = On_cc_merge : {
|
and on_cc_merge = On_cc_merge : {
|
||||||
|
|
@ -193,6 +226,43 @@ module Make(A : ARG)
|
||||||
let[@inline] cc (t:t) = Lazy.force t.cc
|
let[@inline] cc (t:t) = Lazy.force t.cc
|
||||||
let[@inline] tst t = t.tst
|
let[@inline] tst t = t.tst
|
||||||
|
|
||||||
|
let on_cc_merge self ~k f =
|
||||||
|
self.on_cc_merge <- IM.add (CC_key.id k) (On_cc_merge{k;f}) self.on_cc_merge
|
||||||
|
|
||||||
|
let on_cc_new_term self ~k f =
|
||||||
|
self.on_cc_new_term <- On_cc_new_term {k;f} :: self.on_cc_new_term
|
||||||
|
|
||||||
|
let on_cc_merge_all self f =
|
||||||
|
(* just delegate this to the CC *)
|
||||||
|
CC.on_merge (cc self) (fun _cc n1 _th1 n2 _th2 expl -> f self n1 n2 expl)
|
||||||
|
|
||||||
|
let handle_on_cc_merge (self:t) _cc n1 th1 n2 th2 expl : unit =
|
||||||
|
if Key_set.is_empty th1 || Key_set.is_empty th2 then ()
|
||||||
|
else (
|
||||||
|
(* iterate over the intersection of [th1] and [th2] *)
|
||||||
|
IM.iter
|
||||||
|
(fun _ (On_cc_merge {f;k}) ->
|
||||||
|
match Key_set.find k th1, Key_set.find k th2 with
|
||||||
|
| x1, x2 -> f self n1 x1 n2 x2 expl
|
||||||
|
| exception Not_found -> ())
|
||||||
|
self.on_cc_merge
|
||||||
|
)
|
||||||
|
|
||||||
|
(* called by the CC when a term is added *)
|
||||||
|
let handle_on_cc_new_term (self:t) _cc n1 t1 : _ option =
|
||||||
|
match self.on_cc_new_term with
|
||||||
|
| [] -> None
|
||||||
|
| l ->
|
||||||
|
let map =
|
||||||
|
List.fold_left
|
||||||
|
(fun map (On_cc_new_term{k;f}) ->
|
||||||
|
match f self n1 t1 with
|
||||||
|
| None -> map
|
||||||
|
| Some u -> Key_set.add k u map)
|
||||||
|
Key_set.empty l
|
||||||
|
in
|
||||||
|
if Key_set.is_empty map then None else Some map
|
||||||
|
|
||||||
let[@inline] raise_conflict self c : 'a =
|
let[@inline] raise_conflict self c : 'a =
|
||||||
Stat.incr self.count_conflict;
|
Stat.incr self.count_conflict;
|
||||||
match self.msat_acts with
|
match self.msat_acts with
|
||||||
|
|
@ -209,6 +279,16 @@ module Make(A : ARG)
|
||||||
|
|
||||||
let[@inline] propagate_l self p cs : unit = propagate self p (fun()->cs)
|
let[@inline] propagate_l self p cs : unit = propagate self p (fun()->cs)
|
||||||
|
|
||||||
|
let[@inline] cc_add_term self t = CC.add_term (cc self) t
|
||||||
|
let[@inline] cc_merge self n1 n2 e = CC.Theory.merge (cc self) n1 n2 e
|
||||||
|
let cc_merge_t self t1 t2 e =
|
||||||
|
let lazy cc = self.cc in
|
||||||
|
CC.Theory.merge cc (CC.add_term cc t1) (CC.add_term cc t2) e
|
||||||
|
|
||||||
|
let cc_data self ~k n =
|
||||||
|
let data = N.th_data (CC.find (cc self) n) in
|
||||||
|
Key_set.find_opt k data
|
||||||
|
|
||||||
let add_axiom_ self ~keep lits : unit =
|
let add_axiom_ self ~keep lits : unit =
|
||||||
Stat.incr self.count_axiom;
|
Stat.incr self.count_axiom;
|
||||||
match self.msat_acts with
|
match self.msat_acts with
|
||||||
|
|
@ -297,6 +377,9 @@ module Make(A : ARG)
|
||||||
CC.set_as_lit cc n (Lit.abs lit);
|
CC.set_as_lit cc n (Lit.abs lit);
|
||||||
()
|
()
|
||||||
|
|
||||||
|
let on_final_check self f = self.on_final_check <- f :: self.on_final_check
|
||||||
|
let on_partial_check self f = self.on_partial_check <- f :: self.on_partial_check
|
||||||
|
|
||||||
(* propagation from the bool solver *)
|
(* propagation from the bool solver *)
|
||||||
let[@inline] partial_check (self:t) (acts:_ Msat.acts) : unit =
|
let[@inline] partial_check (self:t) (acts:_ Msat.acts) : unit =
|
||||||
check_ ~final:false self acts
|
check_ ~final:false self acts
|
||||||
|
|
@ -343,11 +426,12 @@ module Make(A : ARG)
|
||||||
on_partial_check=[];
|
on_partial_check=[];
|
||||||
on_final_check=[];
|
on_final_check=[];
|
||||||
on_cc_merge=IM.empty;
|
on_cc_merge=IM.empty;
|
||||||
on_cc_new_term=IM.empty;
|
on_cc_new_term=[];
|
||||||
} in
|
} in
|
||||||
ignore (Lazy.force @@ self.cc : CC.t);
|
let lazy cc = self.cc in
|
||||||
|
CC.on_merge cc (handle_on_cc_merge self);
|
||||||
|
CC.on_new_term cc (handle_on_cc_new_term self);
|
||||||
self
|
self
|
||||||
|
|
||||||
end
|
end
|
||||||
|
|
||||||
type conflict = lit list
|
type conflict = lit list
|
||||||
|
|
@ -360,12 +444,15 @@ module Make(A : ARG)
|
||||||
|
|
||||||
module Atom = Sat_solver.Atom
|
module Atom = Sat_solver.Atom
|
||||||
module Proof = Sat_solver.Proof
|
module Proof = Sat_solver.Proof
|
||||||
|
type proof = Proof.t
|
||||||
|
|
||||||
(* main solver state *)
|
(* main solver state *)
|
||||||
type t = {
|
type t = {
|
||||||
si: Solver_internal.t;
|
si: Solver_internal.t;
|
||||||
solver: Sat_solver.t;
|
solver: Sat_solver.t;
|
||||||
stat: Stat.t;
|
stat: Stat.t;
|
||||||
|
count_clause: int Stat.counter;
|
||||||
|
count_solve: int Stat.counter;
|
||||||
(* config: Config.t *)
|
(* config: Config.t *)
|
||||||
}
|
}
|
||||||
type solver = t
|
type solver = t
|
||||||
|
|
@ -380,6 +467,20 @@ module Make(A : ARG)
|
||||||
|
|
||||||
type theory = (module THEORY)
|
type theory = (module THEORY)
|
||||||
|
|
||||||
|
let mk_theory (type st)
|
||||||
|
~name ~create_and_setup
|
||||||
|
?(push_level=fun _ -> ())
|
||||||
|
?(pop_levels=fun _ _ -> ())
|
||||||
|
() : theory =
|
||||||
|
let module Th = struct
|
||||||
|
type t = st
|
||||||
|
let name = name
|
||||||
|
let create_and_setup = create_and_setup
|
||||||
|
let push_level = push_level
|
||||||
|
let pop_levels = pop_levels
|
||||||
|
end in
|
||||||
|
(module Th : THEORY)
|
||||||
|
|
||||||
(** {2 Main} *)
|
(** {2 Main} *)
|
||||||
|
|
||||||
let add_theory (self:t) (th:theory) : unit =
|
let add_theory (self:t) (th:theory) : unit =
|
||||||
|
|
@ -409,6 +510,8 @@ module Make(A : ARG)
|
||||||
si;
|
si;
|
||||||
solver=Sat_solver.create ?store_proof ?size si;
|
solver=Sat_solver.create ?store_proof ?size si;
|
||||||
stat;
|
stat;
|
||||||
|
count_clause=Stat.mk_int stat "solver-clauses";
|
||||||
|
count_solve=Stat.mk_int stat "solver-solve";
|
||||||
} in
|
} in
|
||||||
add_theory_l self theories;
|
add_theory_l self theories;
|
||||||
(* assert [true] and [not false] *)
|
(* assert [true] and [not false] *)
|
||||||
|
|
@ -435,6 +538,15 @@ module Make(A : ARG)
|
||||||
let lit = Lit.atom (tst self) ?sign t in
|
let lit = Lit.atom (tst self) ?sign t in
|
||||||
mk_atom_lit self lit
|
mk_atom_lit self lit
|
||||||
|
|
||||||
|
let add_clause self c = Sat_solver.add_clause_a self.solver (c:_ IArray.t:>_ array) A.Proof.default
|
||||||
|
let add_clause_l self l = add_clause self (IArray.of_list l)
|
||||||
|
|
||||||
|
let add_clause_lits self l =
|
||||||
|
add_clause self @@ IArray.map (mk_atom_lit self) l
|
||||||
|
|
||||||
|
let add_clause_lits_l self l =
|
||||||
|
add_clause self @@ IArray.of_list_map (mk_atom_lit self) l
|
||||||
|
|
||||||
(** {2 Result} *)
|
(** {2 Result} *)
|
||||||
|
|
||||||
module Unknown = struct
|
module Unknown = struct
|
||||||
|
|
@ -464,7 +576,7 @@ module Make(A : ARG)
|
||||||
let pp_model = Model.pp
|
let pp_model = Model.pp
|
||||||
*)
|
*)
|
||||||
|
|
||||||
type res = (Model.t, Proof.t, lit IArray.t, Unknown.t) Sidekick_core.solver_res
|
type res = (Model.t, Proof.t, unit ->lit IArray.t, Unknown.t) Sidekick_core.solver_res
|
||||||
|
|
||||||
(** {2 Main} *)
|
(** {2 Main} *)
|
||||||
|
|
||||||
|
|
@ -482,15 +594,15 @@ module Make(A : ARG)
|
||||||
|
|
||||||
(* map boolean subterms to literals *)
|
(* map boolean subterms to literals *)
|
||||||
let add_bool_subterms_ (self:t) (t:T.t) : unit =
|
let add_bool_subterms_ (self:t) (t:T.t) : unit =
|
||||||
Term.iter_dag t
|
T.iter_dag t
|
||||||
|> Iter.filter (fun t -> Ty.is_prop @@ Term.ty t)
|
|> Iter.filter (fun t -> Ty.is_bool @@ T.ty t)
|
||||||
|> Iter.filter
|
|> Iter.filter
|
||||||
(fun t -> match Term.view t with
|
(fun t -> match CC_A.cc_view t with
|
||||||
| Term.Not _ -> false (* will process the subterm just later *)
|
| CC_view.Not _ -> false (* will process the subterm just later *)
|
||||||
| _ -> true)
|
| _ -> true)
|
||||||
|> Iter.iter
|
|> Iter.iter
|
||||||
(fun sub ->
|
(fun sub ->
|
||||||
Log.debugf 5 (fun k->k "(@[solver.map-to-lit@ :subterm %a@])" Term.pp sub);
|
Log.debugf 5 (fun k->k "(@[solver.map-to-lit@ :subterm %a@])" T.pp sub);
|
||||||
ignore (mk_atom_t self sub : Sat_solver.atom))
|
ignore (mk_atom_t self sub : Sat_solver.atom))
|
||||||
|
|
||||||
let assume (self:t) (c:Lit.t IArray.t) : unit =
|
let assume (self:t) (c:Lit.t IArray.t) : unit =
|
||||||
|
|
@ -498,7 +610,7 @@ module Make(A : ARG)
|
||||||
IArray.iter (fun lit -> add_bool_subterms_ self @@ Lit.term lit) c;
|
IArray.iter (fun lit -> add_bool_subterms_ self @@ Lit.term lit) c;
|
||||||
let c = IArray.to_array_map (Sat_solver.make_atom sat) c in
|
let c = IArray.to_array_map (Sat_solver.make_atom sat) c in
|
||||||
Stat.incr self.count_clause;
|
Stat.incr self.count_clause;
|
||||||
Sat_solver.add_clause_a sat c Proof_default
|
Sat_solver.add_clause_a sat c A.Proof.default
|
||||||
|
|
||||||
(* TODO: remove? use a special constant + micro theory instead?
|
(* TODO: remove? use a special constant + micro theory instead?
|
||||||
let[@inline] assume_distinct self l ~neq lit : unit =
|
let[@inline] assume_distinct self l ~neq lit : unit =
|
||||||
|
|
@ -512,8 +624,6 @@ module Make(A : ARG)
|
||||||
*)
|
*)
|
||||||
()
|
()
|
||||||
|
|
||||||
(* TODO: main loop with iterative deepening of the unrolling limit
|
|
||||||
(not the value depth limit) *)
|
|
||||||
let solve ?(on_exit=[]) ?(check=true) ~assumptions (self:t) : res =
|
let solve ?(on_exit=[]) ?(check=true) ~assumptions (self:t) : res =
|
||||||
let do_on_exit () =
|
let do_on_exit () =
|
||||||
List.iter (fun f->f()) on_exit;
|
List.iter (fun f->f()) on_exit;
|
||||||
|
|
@ -523,17 +633,15 @@ module Make(A : ARG)
|
||||||
match r with
|
match r with
|
||||||
| Sat_solver.Sat st ->
|
| Sat_solver.Sat st ->
|
||||||
Log.debugf 1 (fun k->k "SAT");
|
Log.debugf 1 (fun k->k "SAT");
|
||||||
let lits f = st.iter_trail f (fun _ -> ()) in
|
let _lits f = st.iter_trail f (fun _ -> ()) in
|
||||||
let m = Theory_combine.mk_model (th_combine self) lits in
|
let m = Model.empty in
|
||||||
|
(* TODO Theory_combine.mk_model (th_combine self) lits *)
|
||||||
do_on_exit ();
|
do_on_exit ();
|
||||||
Sat m
|
Sat m
|
||||||
(*
|
|
||||||
let env = Ast.env_empty in
|
|
||||||
let m = Model.make ~env in
|
|
||||||
…
|
|
||||||
Unknown U_incomplete (* TODO *)
|
|
||||||
*)
|
|
||||||
| Sat_solver.Unsat us ->
|
| Sat_solver.Unsat us ->
|
||||||
|
let uc () =
|
||||||
|
clause_of_mclause @@ us.Msat.unsat_conflict ()
|
||||||
|
in
|
||||||
let pr =
|
let pr =
|
||||||
try
|
try
|
||||||
let pr = us.get_proof () in
|
let pr = us.get_proof () in
|
||||||
|
|
@ -542,6 +650,6 @@ module Make(A : ARG)
|
||||||
with Msat.Solver_intf.No_proof -> None
|
with Msat.Solver_intf.No_proof -> None
|
||||||
in
|
in
|
||||||
do_on_exit ();
|
do_on_exit ();
|
||||||
Unsat pr
|
Unsat {proof=pr; unsat_core=uc}
|
||||||
|
|
||||||
end
|
end
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue