wip: msat solver

This commit is contained in:
Simon Cruanes 2019-06-01 16:55:47 -05:00
parent 6ef3da9d02
commit 44259ec5fc

View file

@ -5,20 +5,27 @@
module Vec = Msat.Vec
module Log = Msat.Log
module IM = Util.Int_map
module CC_view = Sidekick_core.CC_view
module type ARG = sig
include Sidekick_core.TERM_LIT_PROOF
val cc_view : Term.t -> (Fun.t, Term.t, Term.t Iter.t) Sidekick_core.CC_view.t
module A : Sidekick_core.CORE_TYPES
open A
val cc_view : Term.t -> (Fun.t, Term.t, Term.t Iter.t) CC_view.t
end
module type S = Sidekick_core.SOLVER
module Make(A : ARG)
(* : S with type A.Term.t = A.Term.t *)
module Make(Solver_arg : ARG)
(* : S with module A = Solver_arg.A *)
= struct
module A = Solver_arg.A
module T = A.Term
module Ty = A.Ty
module Lit = A.Lit
type term = T.t
type ty = Ty.t
type lit = Lit.t
type value = A.Value.t
(** Custom keys for theory data.
This imitates the classic tricks for heterogeneous maps
@ -64,28 +71,33 @@ module Make(A : ARG)
to the congruence closure. *)
module Key_set = struct
type 'a key = 'a CC_key.t
type k1 =
| K1 : {
type ke =
| KE : {
k: 'a key;
e: exn;
} -> k1
} -> ke
type t = k1 IM.t
type t = ke IM.t
let empty = IM.empty
let is_empty = IM.is_empty
let[@inline] mem k t = IM.mem (CC_key.id k) t
let find (type a) (k : a key) (self:t) : a option =
(** Find the content for this key.
@raise Not_found if not present *)
let find (type a) (k : a key) (self:t) : a =
let (module K) = k in
match IM.find K.id self with
| K1 {e=K.Store v;_} -> Some v
| _ -> None
| exception Not_found -> None
| KE {e=K.Store v;_} -> v
| _ -> raise_notrace Not_found
let[@inline] find_opt k self = try Some (find k self) with Not_found -> None
let add (type a) (k : a key) (v:a) (self:t) : t =
let (module K) = k in
IM.add K.id (K1 {k; e=K.Store v}) self
IM.add K.id (KE {k; e=K.Store v}) self
let remove (type a) (k: a key) self : t =
let (module K) = k in
@ -98,15 +110,34 @@ module Make(A : ARG)
| None, None -> None
| Some v, None
| None, Some v -> Some v
| Some (K1 {k=(module K1) as key1; e=pair1; }), Some (K1{e=pair2;_}) ->
| Some (KE {k=(module KE) as key1; e=pair1; }), Some (KE{e=pair2;_}) ->
match pair1, pair2 with
| K1.Store v1, K1.Store v2 ->
let v12 = K1.merge v1 v2 in (* merge content *)
Some (K1 {k=key1; e=K1.Store v12; })
| KE.Store v1, KE.Store v2 ->
let v12 = KE.merge v1 v2 in (* merge content *)
Some (KE {k=key1; e=KE.Store v12; })
| _ -> assert false)
m1 m2
let pp_pair out (K1 {k=(module K);e=x; _}) =
type iter_fun = {
iter_fun: 'a. 'a key -> 'a -> 'a -> unit;
} [@@unboxed]
let iter_inter (f: iter_fun) (m1:t) (m2:t) : unit =
if is_empty m1 || is_empty m2 then ()
else (
IM.iter
(fun i (KE {k=(module Key) as key;e=e1}) ->
match IM.find i m2 with
| KE {e=e2;_} ->
begin match e1, e2 with
| Key.Store x, Key.Store y -> f.iter_fun key x y
| _ -> assert false
end
| exception Not_found -> ())
m1
)
let pp_pair out (KE {k=(module K);e=x; _}) =
match x with
| K.Store x -> K.pp out x
| _ -> assert false
@ -117,26 +148,28 @@ module Make(A : ARG)
end
(* the full argument to the congruence closure *)
module A = struct
include A
module CC_A = struct
include Solver_arg
module Data = Key_set
module Actions = struct
type t = {
raise_conflict : 'a. Lit.t list -> Proof.t -> 'a;
propagate : Lit.t -> reason:Lit.t Iter.t -> Proof.t -> unit;
raise_conflict : 'a. Lit.t list -> A.Proof.t -> 'a;
propagate : Lit.t -> reason:Lit.t Iter.t -> A.Proof.t -> unit;
}
let[@inline] raise_conflict a lits p = a.raise_conflict lits p
let[@inline] propagate a lit ~reason p = a.propagate lit ~reason p
end
end
module CC = Sidekick_cc.Make(A)
module CC = Sidekick_cc.Make(CC_A)
module Expl = CC.Expl
module N = CC.N
(** Internal solver, given to theories and to Msat *)
module Solver_internal = struct
module Solver_internal
: Sidekick_core.SOLVER_INTERNAL with module A = A
= struct
module A = A
type th_states =
@ -163,8 +196,8 @@ module Make(A : ARG)
mutable msat_acts: msat_acts option;
mutable on_partial_check: (t -> lit Iter.t -> unit) list;
mutable on_final_check: (t -> lit Iter.t -> unit) list;
mutable on_cc_merge: on_cc_merge list IM.t;
mutable on_cc_new_term : on_cc_new_term IM.t;
mutable on_cc_merge: on_cc_merge IM.t;
mutable on_cc_new_term : on_cc_new_term list;
}
and on_cc_merge = On_cc_merge : {
@ -193,6 +226,43 @@ module Make(A : ARG)
let[@inline] cc (t:t) = Lazy.force t.cc
let[@inline] tst t = t.tst
let on_cc_merge self ~k f =
self.on_cc_merge <- IM.add (CC_key.id k) (On_cc_merge{k;f}) self.on_cc_merge
let on_cc_new_term self ~k f =
self.on_cc_new_term <- On_cc_new_term {k;f} :: self.on_cc_new_term
let on_cc_merge_all self f =
(* just delegate this to the CC *)
CC.on_merge (cc self) (fun _cc n1 _th1 n2 _th2 expl -> f self n1 n2 expl)
let handle_on_cc_merge (self:t) _cc n1 th1 n2 th2 expl : unit =
if Key_set.is_empty th1 || Key_set.is_empty th2 then ()
else (
(* iterate over the intersection of [th1] and [th2] *)
IM.iter
(fun _ (On_cc_merge {f;k}) ->
match Key_set.find k th1, Key_set.find k th2 with
| x1, x2 -> f self n1 x1 n2 x2 expl
| exception Not_found -> ())
self.on_cc_merge
)
(* called by the CC when a term is added *)
let handle_on_cc_new_term (self:t) _cc n1 t1 : _ option =
match self.on_cc_new_term with
| [] -> None
| l ->
let map =
List.fold_left
(fun map (On_cc_new_term{k;f}) ->
match f self n1 t1 with
| None -> map
| Some u -> Key_set.add k u map)
Key_set.empty l
in
if Key_set.is_empty map then None else Some map
let[@inline] raise_conflict self c : 'a =
Stat.incr self.count_conflict;
match self.msat_acts with
@ -209,6 +279,16 @@ module Make(A : ARG)
let[@inline] propagate_l self p cs : unit = propagate self p (fun()->cs)
let[@inline] cc_add_term self t = CC.add_term (cc self) t
let[@inline] cc_merge self n1 n2 e = CC.Theory.merge (cc self) n1 n2 e
let cc_merge_t self t1 t2 e =
let lazy cc = self.cc in
CC.Theory.merge cc (CC.add_term cc t1) (CC.add_term cc t2) e
let cc_data self ~k n =
let data = N.th_data (CC.find (cc self) n) in
Key_set.find_opt k data
let add_axiom_ self ~keep lits : unit =
Stat.incr self.count_axiom;
match self.msat_acts with
@ -297,6 +377,9 @@ module Make(A : ARG)
CC.set_as_lit cc n (Lit.abs lit);
()
let on_final_check self f = self.on_final_check <- f :: self.on_final_check
let on_partial_check self f = self.on_partial_check <- f :: self.on_partial_check
(* propagation from the bool solver *)
let[@inline] partial_check (self:t) (acts:_ Msat.acts) : unit =
check_ ~final:false self acts
@ -343,11 +426,12 @@ module Make(A : ARG)
on_partial_check=[];
on_final_check=[];
on_cc_merge=IM.empty;
on_cc_new_term=IM.empty;
on_cc_new_term=[];
} in
ignore (Lazy.force @@ self.cc : CC.t);
let lazy cc = self.cc in
CC.on_merge cc (handle_on_cc_merge self);
CC.on_new_term cc (handle_on_cc_new_term self);
self
end
type conflict = lit list
@ -360,12 +444,15 @@ module Make(A : ARG)
module Atom = Sat_solver.Atom
module Proof = Sat_solver.Proof
type proof = Proof.t
(* main solver state *)
type t = {
si: Solver_internal.t;
solver: Sat_solver.t;
stat: Stat.t;
count_clause: int Stat.counter;
count_solve: int Stat.counter;
(* config: Config.t *)
}
type solver = t
@ -380,6 +467,20 @@ module Make(A : ARG)
type theory = (module THEORY)
let mk_theory (type st)
~name ~create_and_setup
?(push_level=fun _ -> ())
?(pop_levels=fun _ _ -> ())
() : theory =
let module Th = struct
type t = st
let name = name
let create_and_setup = create_and_setup
let push_level = push_level
let pop_levels = pop_levels
end in
(module Th : THEORY)
(** {2 Main} *)
let add_theory (self:t) (th:theory) : unit =
@ -409,6 +510,8 @@ module Make(A : ARG)
si;
solver=Sat_solver.create ?store_proof ?size si;
stat;
count_clause=Stat.mk_int stat "solver-clauses";
count_solve=Stat.mk_int stat "solver-solve";
} in
add_theory_l self theories;
(* assert [true] and [not false] *)
@ -435,6 +538,15 @@ module Make(A : ARG)
let lit = Lit.atom (tst self) ?sign t in
mk_atom_lit self lit
let add_clause self c = Sat_solver.add_clause_a self.solver (c:_ IArray.t:>_ array) A.Proof.default
let add_clause_l self l = add_clause self (IArray.of_list l)
let add_clause_lits self l =
add_clause self @@ IArray.map (mk_atom_lit self) l
let add_clause_lits_l self l =
add_clause self @@ IArray.of_list_map (mk_atom_lit self) l
(** {2 Result} *)
module Unknown = struct
@ -464,7 +576,7 @@ module Make(A : ARG)
let pp_model = Model.pp
*)
type res = (Model.t, Proof.t, lit IArray.t, Unknown.t) Sidekick_core.solver_res
type res = (Model.t, Proof.t, unit ->lit IArray.t, Unknown.t) Sidekick_core.solver_res
(** {2 Main} *)
@ -482,15 +594,15 @@ module Make(A : ARG)
(* map boolean subterms to literals *)
let add_bool_subterms_ (self:t) (t:T.t) : unit =
Term.iter_dag t
|> Iter.filter (fun t -> Ty.is_prop @@ Term.ty t)
T.iter_dag t
|> Iter.filter (fun t -> Ty.is_bool @@ T.ty t)
|> Iter.filter
(fun t -> match Term.view t with
| Term.Not _ -> false (* will process the subterm just later *)
(fun t -> match CC_A.cc_view t with
| CC_view.Not _ -> false (* will process the subterm just later *)
| _ -> true)
|> Iter.iter
(fun sub ->
Log.debugf 5 (fun k->k "(@[solver.map-to-lit@ :subterm %a@])" Term.pp sub);
Log.debugf 5 (fun k->k "(@[solver.map-to-lit@ :subterm %a@])" T.pp sub);
ignore (mk_atom_t self sub : Sat_solver.atom))
let assume (self:t) (c:Lit.t IArray.t) : unit =
@ -498,7 +610,7 @@ module Make(A : ARG)
IArray.iter (fun lit -> add_bool_subterms_ self @@ Lit.term lit) c;
let c = IArray.to_array_map (Sat_solver.make_atom sat) c in
Stat.incr self.count_clause;
Sat_solver.add_clause_a sat c Proof_default
Sat_solver.add_clause_a sat c A.Proof.default
(* TODO: remove? use a special constant + micro theory instead?
let[@inline] assume_distinct self l ~neq lit : unit =
@ -512,8 +624,6 @@ module Make(A : ARG)
*)
()
(* TODO: main loop with iterative deepening of the unrolling limit
(not the value depth limit) *)
let solve ?(on_exit=[]) ?(check=true) ~assumptions (self:t) : res =
let do_on_exit () =
List.iter (fun f->f()) on_exit;
@ -523,17 +633,15 @@ module Make(A : ARG)
match r with
| Sat_solver.Sat st ->
Log.debugf 1 (fun k->k "SAT");
let lits f = st.iter_trail f (fun _ -> ()) in
let m = Theory_combine.mk_model (th_combine self) lits in
let _lits f = st.iter_trail f (fun _ -> ()) in
let m = Model.empty in
(* TODO Theory_combine.mk_model (th_combine self) lits *)
do_on_exit ();
Sat m
(*
let env = Ast.env_empty in
let m = Model.make ~env in
Unknown U_incomplete (* TODO *)
*)
| Sat_solver.Unsat us ->
let uc () =
clause_of_mclause @@ us.Msat.unsat_conflict ()
in
let pr =
try
let pr = us.get_proof () in
@ -542,6 +650,6 @@ module Make(A : ARG)
with Msat.Solver_intf.No_proof -> None
in
do_on_exit ();
Unsat pr
Unsat {proof=pr; unsat_core=uc}
end