mirror of
https://github.com/c-cube/sidekick.git
synced 2026-01-29 04:44:52 -05:00
refactor: simplify literals; remove useless casts in CC; bit for pending nodes
This commit is contained in:
parent
b12db3f03e
commit
9ac274fc09
10 changed files with 68 additions and 107 deletions
|
|
@ -728,7 +728,6 @@ module Make (Th : Theory_intf.S) = struct
|
||||||
end
|
end
|
||||||
|
|
||||||
let[@inline] theory st = Lazy.force st.th
|
let[@inline] theory st = Lazy.force st.th
|
||||||
|
|
||||||
let[@inline] nb_clauses st = Vec.size st.clauses
|
let[@inline] nb_clauses st = Vec.size st.clauses
|
||||||
let[@inline] decision_level st = Vec.size st.elt_levels
|
let[@inline] decision_level st = Vec.size st.elt_levels
|
||||||
let[@inline] base_level st = if decision_level st > 0 then 1 else 0 (* first level=assumptions *)
|
let[@inline] base_level st = if decision_level st > 0 then 1 else 0 (* first level=assumptions *)
|
||||||
|
|
|
||||||
|
|
@ -72,8 +72,8 @@ let[@inline] on_backtrack cc f : unit =
|
||||||
let[@inline] is_root_ (n:node) : bool = n.n_root == n
|
let[@inline] is_root_ (n:node) : bool = n.n_root == n
|
||||||
|
|
||||||
let[@inline] size_ (r:repr) =
|
let[@inline] size_ (r:repr) =
|
||||||
assert (is_root_ (r:>node));
|
assert (is_root_ r);
|
||||||
Bag.size (r :> node).n_parents
|
Bag.size r.n_parents
|
||||||
|
|
||||||
(* check if [t] is in the congruence closure.
|
(* check if [t] is in the congruence closure.
|
||||||
Invariant: [in_cc t => in_cc u, forall u subterm t] *)
|
Invariant: [in_cc t => in_cc u, forall u subterm t] *)
|
||||||
|
|
@ -90,9 +90,9 @@ let rec find_rec cc (n:node) : repr =
|
||||||
let old_root = n.n_root in
|
let old_root = n.n_root in
|
||||||
let root = find_rec cc old_root in
|
let root = find_rec cc old_root in
|
||||||
(* path compression *)
|
(* path compression *)
|
||||||
if (root :> node) != old_root then (
|
if root != old_root then (
|
||||||
on_backtrack cc (fun () -> n.n_root <- old_root);
|
on_backtrack cc (fun () -> n.n_root <- old_root);
|
||||||
n.n_root <- (root :> node);
|
n.n_root <- root;
|
||||||
);
|
);
|
||||||
root
|
root
|
||||||
)
|
)
|
||||||
|
|
@ -154,8 +154,11 @@ let is_done (cc:t): bool =
|
||||||
Vec.is_empty cc.combine
|
Vec.is_empty cc.combine
|
||||||
|
|
||||||
let push_pending cc t : unit =
|
let push_pending cc t : unit =
|
||||||
|
if not @@ Equiv_class.get_field Equiv_class.field_is_pending t then (
|
||||||
Log.debugf 5 (fun k->k "(@[<hv1>cc.push_pending@ %a@])" Equiv_class.pp t);
|
Log.debugf 5 (fun k->k "(@[<hv1>cc.push_pending@ %a@])" Equiv_class.pp t);
|
||||||
|
Equiv_class.set_field Equiv_class.field_is_pending true t;
|
||||||
Vec.push cc.pending t
|
Vec.push cc.pending t
|
||||||
|
)
|
||||||
|
|
||||||
let push_combine cc t u e : unit =
|
let push_combine cc t u e : unit =
|
||||||
Log.debugf 5
|
Log.debugf 5
|
||||||
|
|
@ -322,11 +325,13 @@ let rec update_pending (cc:t): unit =
|
||||||
might have changed *)
|
might have changed *)
|
||||||
while not (Vec.is_empty cc.pending) do
|
while not (Vec.is_empty cc.pending) do
|
||||||
let n = Vec.pop_last cc.pending in
|
let n = Vec.pop_last cc.pending in
|
||||||
|
Equiv_class.set_field Equiv_class.field_is_pending false n;
|
||||||
(* check if some parent collided *)
|
(* check if some parent collided *)
|
||||||
begin match find_by_signature cc n.n_term with
|
begin match find_by_signature cc n.n_term with
|
||||||
| None ->
|
| None ->
|
||||||
(* add to the signature table [sig(n) --> n] *)
|
(* add to the signature table [sig(n) --> n] *)
|
||||||
add_signature cc n.n_term n
|
add_signature cc n.n_term n
|
||||||
|
| Some u when n == u -> ()
|
||||||
| Some u ->
|
| Some u ->
|
||||||
(* must combine [t] with [r] *)
|
(* must combine [t] with [r] *)
|
||||||
if not @@ same_class cc n u then (
|
if not @@ same_class cc n u then (
|
||||||
|
|
@ -337,8 +342,7 @@ let rec update_pending (cc:t): unit =
|
||||||
| App_cst (f1, a1), App_cst (f2, a2) ->
|
| App_cst (f1, a1), App_cst (f2, a2) ->
|
||||||
assert (Cst.equal f1 f2);
|
assert (Cst.equal f1 f2);
|
||||||
assert (IArray.length a1 = IArray.length a2);
|
assert (IArray.length a1 = IArray.length a2);
|
||||||
Explanation.mk_merges @@
|
Explanation.mk_merges @@ IArray.map2 (fun u1 u2 -> add_ cc u1, add_ cc u2) a1 a2
|
||||||
IArray.map2 (fun u1 u2 -> add_ cc u1, add_ cc u2) a1 a2
|
|
||||||
| If _, _ | App_cst _, _ | Bool _, _
|
| If _, _ | App_cst _, _ | Bool _, _
|
||||||
-> assert false
|
-> assert false
|
||||||
in
|
in
|
||||||
|
|
@ -386,14 +390,12 @@ and update_combine cc =
|
||||||
in
|
in
|
||||||
(* remove [ra.parents] from signature, put them into [st.pending] *)
|
(* remove [ra.parents] from signature, put them into [st.pending] *)
|
||||||
begin
|
begin
|
||||||
Bag.to_seq (r_from:>node).n_parents
|
Bag.to_seq r_from.n_parents
|
||||||
|> Sequence.iter
|
|> Sequence.iter
|
||||||
(fun parent -> push_pending cc parent)
|
(fun parent -> push_pending cc parent)
|
||||||
end;
|
end;
|
||||||
(* perform [union ra rb] *)
|
(* perform [union ra rb] *)
|
||||||
begin
|
begin
|
||||||
let r_from = (r_from :> node) in
|
|
||||||
let r_into = (r_into :> node) in
|
|
||||||
let r_into_old_parents = r_into.n_parents in
|
let r_into_old_parents = r_into.n_parents in
|
||||||
let r_into_old_tags = r_into.n_tags in
|
let r_into_old_tags = r_into.n_tags in
|
||||||
on_backtrack cc
|
on_backtrack cc
|
||||||
|
|
@ -485,14 +487,14 @@ let add_seq cc seq =
|
||||||
|
|
||||||
(* to do after backtracking: reset task lists *)
|
(* to do after backtracking: reset task lists *)
|
||||||
let reset_tasks cc : unit =
|
let reset_tasks cc : unit =
|
||||||
|
Vec.iter (Equiv_class.set_field Equiv_class.field_is_pending false) cc.pending;
|
||||||
Vec.clear cc.pending;
|
Vec.clear cc.pending;
|
||||||
Vec.clear cc.combine;
|
Vec.clear cc.combine;
|
||||||
()
|
()
|
||||||
|
|
||||||
(* assert that this boolean literal holds *)
|
(* assert that this boolean literal holds *)
|
||||||
let assert_lit cc lit : unit = match Lit.view lit with
|
let assert_lit cc lit : unit =
|
||||||
| Lit_fresh _ -> ()
|
let t = Lit.view lit in
|
||||||
| Lit_atom t ->
|
|
||||||
assert (Ty.is_prop t.term_ty);
|
assert (Ty.is_prop t.term_ty);
|
||||||
Log.debugf 5 (fun k->k "(@[cc.assert_lit@ %a@])" Lit.pp lit);
|
Log.debugf 5 (fun k->k "(@[cc.assert_lit@ %a@])" Lit.pp lit);
|
||||||
let sign = Lit.sign lit in
|
let sign = Lit.sign lit in
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ type t = cc_node
|
||||||
type payload = cc_node_payload = ..
|
type payload = cc_node_payload = ..
|
||||||
|
|
||||||
let field_is_active = Node_bits.mk_field()
|
let field_is_active = Node_bits.mk_field()
|
||||||
|
let field_is_pending = Node_bits.mk_field()
|
||||||
let () = Node_bits.freeze()
|
let () = Node_bits.freeze()
|
||||||
|
|
||||||
let[@inline] equal (n1:t) n2 = n1==n2
|
let[@inline] equal (n1:t) n2 = n1==n2
|
||||||
|
|
@ -59,6 +60,9 @@ let payload_pred ~f:p n =
|
||||||
| l -> List.exists p l
|
| l -> List.exists p l
|
||||||
end
|
end
|
||||||
|
|
||||||
|
let[@inline] get_field f t = Node_bits.get f t.n_bits
|
||||||
|
let[@inline] set_field f b t = t.n_bits <- Node_bits.set f b t.n_bits
|
||||||
|
|
||||||
module Tbl = CCHashtbl.Make(struct
|
module Tbl = CCHashtbl.Make(struct
|
||||||
type t = cc_node
|
type t = cc_node
|
||||||
let equal = equal
|
let equal = equal
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,9 @@ val field_is_active : Node_bits.field
|
||||||
(** The term is needed for evaluation. We must try to evaluate it
|
(** The term is needed for evaluation. We must try to evaluate it
|
||||||
or to find a value for it using the theory *)
|
or to find a value for it using the theory *)
|
||||||
|
|
||||||
|
val field_is_pending : Node_bits.field
|
||||||
|
(** true iff the node is in the [cc.pending] queue *)
|
||||||
|
|
||||||
(** {2 basics} *)
|
(** {2 basics} *)
|
||||||
|
|
||||||
val term : t -> term
|
val term : t -> term
|
||||||
|
|
@ -49,6 +52,9 @@ val set_payload : ?can_erase:(payload -> bool) -> t -> payload -> unit
|
||||||
@param can_erase if provided, checks whether an existing value
|
@param can_erase if provided, checks whether an existing value
|
||||||
is to be replaced instead of adding a new entry *)
|
is to be replaced instead of adding a new entry *)
|
||||||
|
|
||||||
|
val get_field : Node_bits.field -> t -> bool
|
||||||
|
val set_field : Node_bits.field -> bool -> t -> unit
|
||||||
|
|
||||||
module Tbl : CCHashtbl.S with type key = t
|
module Tbl : CCHashtbl.S with type key = t
|
||||||
|
|
||||||
(**/**)
|
(**/**)
|
||||||
|
|
|
||||||
|
|
@ -2,48 +2,31 @@
|
||||||
open Solver_types
|
open Solver_types
|
||||||
|
|
||||||
type t = lit = {
|
type t = lit = {
|
||||||
lit_view : lit_view;
|
lit_term: term;
|
||||||
lit_sign : bool
|
lit_sign : bool
|
||||||
}
|
}
|
||||||
|
|
||||||
and view = lit_view =
|
|
||||||
| Lit_fresh of ID.t
|
|
||||||
| Lit_atom of term
|
|
||||||
|
|
||||||
let neg l = {l with lit_sign=not l.lit_sign}
|
let neg l = {l with lit_sign=not l.lit_sign}
|
||||||
|
|
||||||
let sign t = t.lit_sign
|
let[@inline] sign t = t.lit_sign
|
||||||
let view (t:t): lit_view = t.lit_view
|
let[@inline] view (t:t): term = t.lit_term
|
||||||
|
|
||||||
let abs t: t = {t with lit_sign=true}
|
let[@inline] abs t: t = {t with lit_sign=true}
|
||||||
|
|
||||||
let make ~sign v = {lit_sign=sign; lit_view=v}
|
let make ~sign t = {lit_sign=sign; lit_term=t}
|
||||||
|
|
||||||
(* assume the ID is fresh *)
|
let dummy = make ~sign:true Term.dummy
|
||||||
let fresh_with id = make ~sign:true (Lit_fresh id)
|
|
||||||
|
|
||||||
(* fresh boolean literal *)
|
|
||||||
let fresh: unit -> t =
|
|
||||||
let n = ref 0 in
|
|
||||||
fun () ->
|
|
||||||
let id = ID.makef "#fresh_%d" !n in
|
|
||||||
incr n;
|
|
||||||
make ~sign:true (Lit_fresh id)
|
|
||||||
|
|
||||||
let dummy = fresh()
|
|
||||||
|
|
||||||
let atom ?(sign=true) (t:term) : t =
|
let atom ?(sign=true) (t:term) : t =
|
||||||
let t, sign' = Term.abs t in
|
let t, sign' = Term.abs t in
|
||||||
let sign = if not sign' then not sign else sign in
|
let sign = if not sign' then not sign else sign in
|
||||||
make ~sign (Lit_atom t)
|
make ~sign t
|
||||||
|
|
||||||
let as_atom (lit:t) : (term * bool) option = match lit.lit_view with
|
let[@inline] as_atom (lit:t) = lit.lit_term, lit.lit_sign
|
||||||
| Lit_atom t -> Some (t, lit.lit_sign)
|
|
||||||
| _ -> None
|
|
||||||
|
|
||||||
let hash = hash_lit
|
let hash = hash_lit
|
||||||
let compare = cmp_lit
|
let compare = cmp_lit
|
||||||
let equal a b = compare a b = 0
|
let[@inline] equal a b = compare a b = 0
|
||||||
let pp = pp_lit
|
let pp = pp_lit
|
||||||
let print = pp
|
let print = pp
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,21 +3,15 @@
|
||||||
open Solver_types
|
open Solver_types
|
||||||
|
|
||||||
type t = lit = {
|
type t = lit = {
|
||||||
lit_view : lit_view;
|
lit_term: term;
|
||||||
lit_sign : bool
|
lit_sign : bool
|
||||||
}
|
}
|
||||||
|
|
||||||
and view = lit_view =
|
|
||||||
| Lit_fresh of ID.t
|
|
||||||
| Lit_atom of term
|
|
||||||
|
|
||||||
val neg : t -> t
|
val neg : t -> t
|
||||||
val abs : t -> t
|
val abs : t -> t
|
||||||
val sign : t -> bool
|
val sign : t -> bool
|
||||||
val view : t -> lit_view
|
val view : t -> term
|
||||||
val as_atom : t -> (term * bool) option
|
val as_atom : t -> term * bool
|
||||||
val fresh_with : ID.t -> t
|
|
||||||
val fresh : unit -> t
|
|
||||||
val dummy : t
|
val dummy : t
|
||||||
val atom : ?sign:bool -> term -> t
|
val atom : ?sign:bool -> term -> t
|
||||||
val hash : t -> int
|
val hash : t -> int
|
||||||
|
|
|
||||||
|
|
@ -55,14 +55,10 @@ and explanation =
|
||||||
|
|
||||||
(* boolean literal *)
|
(* boolean literal *)
|
||||||
and lit = {
|
and lit = {
|
||||||
lit_view: lit_view;
|
lit_term: term;
|
||||||
lit_sign: bool;
|
lit_sign: bool;
|
||||||
}
|
}
|
||||||
|
|
||||||
and lit_view =
|
|
||||||
| Lit_fresh of ID.t (* fresh literals *)
|
|
||||||
| Lit_atom of term
|
|
||||||
|
|
||||||
and cst = {
|
and cst = {
|
||||||
cst_id: ID.t;
|
cst_id: ID.t;
|
||||||
cst_view: cst_view;
|
cst_view: cst_view;
|
||||||
|
|
@ -143,25 +139,13 @@ let[@inline] term_cmp_ a b = CCInt.compare a.term_id b.term_id
|
||||||
let cmp_lit a b =
|
let cmp_lit a b =
|
||||||
let c = CCBool.compare a.lit_sign b.lit_sign in
|
let c = CCBool.compare a.lit_sign b.lit_sign in
|
||||||
if c<>0 then c
|
if c<>0 then c
|
||||||
else (
|
else term_cmp_ a.lit_term b.lit_term
|
||||||
let int_of_cell_ = function
|
|
||||||
| Lit_fresh _ -> 0
|
|
||||||
| Lit_atom _ -> 1
|
|
||||||
in
|
|
||||||
match a.lit_view, b.lit_view with
|
|
||||||
| Lit_fresh i1, Lit_fresh i2 -> ID.compare i1 i2
|
|
||||||
| Lit_atom t1, Lit_atom t2 -> term_cmp_ t1 t2
|
|
||||||
| Lit_fresh _, _ | Lit_atom _, _
|
|
||||||
-> CCInt.compare (int_of_cell_ a.lit_view) (int_of_cell_ b.lit_view)
|
|
||||||
)
|
|
||||||
|
|
||||||
let cst_compare a b = ID.compare a.cst_id b.cst_id
|
let cst_compare a b = ID.compare a.cst_id b.cst_id
|
||||||
|
|
||||||
let hash_lit a =
|
let hash_lit a =
|
||||||
let sign = a.lit_sign in
|
let sign = a.lit_sign in
|
||||||
match a.lit_view with
|
Hash.combine3 2 (Hash.bool sign) (term_hash_ a.lit_term)
|
||||||
| Lit_fresh i -> Hash.combine3 1 (Hash.bool sign) (ID.hash i)
|
|
||||||
| Lit_atom t -> Hash.combine3 2 (Hash.bool sign) (term_hash_ t)
|
|
||||||
|
|
||||||
let cmp_cc_node a b = term_cmp_ a.n_term b.n_term
|
let cmp_cc_node a b = term_cmp_ a.n_term b.n_term
|
||||||
|
|
||||||
|
|
@ -236,12 +220,8 @@ let pp_term = pp_term_top ~ids:false
|
||||||
let pp_term_view = pp_term_view ~pp_id:ID.pp_name ~pp_t:pp_term
|
let pp_term_view = pp_term_view ~pp_id:ID.pp_name ~pp_t:pp_term
|
||||||
|
|
||||||
let pp_lit out l =
|
let pp_lit out l =
|
||||||
let pp_lit_view out = function
|
if l.lit_sign then pp_term out l.lit_term
|
||||||
| Lit_fresh i -> Format.fprintf out "#%a" ID.pp i
|
else Format.fprintf out "(@[@<1>¬@ %a@])" pp_term l.lit_term
|
||||||
| Lit_atom t -> pp_term out t
|
|
||||||
in
|
|
||||||
if l.lit_sign then pp_lit_view out l.lit_view
|
|
||||||
else Format.fprintf out "(@[@<1>¬@ %a@])" pp_lit_view l.lit_view
|
|
||||||
|
|
||||||
let pp_cc_node out n = pp_term out n.n_term
|
let pp_cc_node out n = pp_term out n.n_term
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -48,10 +48,9 @@ let assume_lit (self:t) (lit:Lit.t) : unit =
|
||||||
(fun k->k "(@[<1>@{<green>th_combine.assume_lit@}@ @[%a@]@])" Lit.pp lit);
|
(fun k->k "(@[<1>@{<green>th_combine.assume_lit@}@ @[%a@]@])" Lit.pp lit);
|
||||||
(* check consistency first *)
|
(* check consistency first *)
|
||||||
begin match Lit.view lit with
|
begin match Lit.view lit with
|
||||||
| Lit_fresh _ -> ()
|
| {term_view=Bool true; _} -> ()
|
||||||
| Lit_atom {term_view=Bool true; _} -> ()
|
| {term_view=Bool false; _} -> ()
|
||||||
| Lit_atom {term_view=Bool false; _} -> ()
|
| _ ->
|
||||||
| Lit_atom _ ->
|
|
||||||
(* transmit to theories. *)
|
(* transmit to theories. *)
|
||||||
C_clos.assert_lit (cc self) lit;
|
C_clos.assert_lit (cc self) lit;
|
||||||
theories self (fun (module Th) -> Th.on_assert Th.state lit);
|
theories self (fun (module Th) -> Th.on_assert Th.state lit);
|
||||||
|
|
@ -98,11 +97,9 @@ let assume_real (self:t) (slice:Lit.t Sat_solver.slice_actions) =
|
||||||
)
|
)
|
||||||
|
|
||||||
let add_formula (self:t) (lit:Lit.t) =
|
let add_formula (self:t) (lit:Lit.t) =
|
||||||
match Lit.view lit with
|
let t = Lit.view lit in
|
||||||
| Lit_atom t ->
|
|
||||||
let lazy cc = self.cc in
|
let lazy cc = self.cc in
|
||||||
ignore (C_clos.add cc t : cc_node)
|
ignore (C_clos.add cc t : cc_node)
|
||||||
| Lit_fresh _ -> ()
|
|
||||||
|
|
||||||
(* propagation from the bool solver *)
|
(* propagation from the bool solver *)
|
||||||
let assume (self:t) (slice:_ Sat_solver.slice_actions) =
|
let assume (self:t) (slice:_ Sat_solver.slice_actions) =
|
||||||
|
|
|
||||||
|
|
@ -236,13 +236,11 @@ let tseitin (self:t) (lit:Lit.t) (lit_t:term) (v:term view) : unit =
|
||||||
)
|
)
|
||||||
|
|
||||||
let on_assert (self:t) (lit:Lit.t) =
|
let on_assert (self:t) (lit:Lit.t) =
|
||||||
match Lit.view lit with
|
let t = Lit.view lit in
|
||||||
| Lit.Lit_atom t ->
|
|
||||||
begin match view t with
|
begin match view t with
|
||||||
| B_atom _ -> ()
|
| B_atom _ -> ()
|
||||||
| v -> tseitin self lit t v
|
| v -> tseitin self lit t v
|
||||||
end
|
end
|
||||||
| _ -> ()
|
|
||||||
|
|
||||||
let final_check _ _ : unit = ()
|
let final_check _ _ : unit = ()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -228,10 +228,8 @@ let check_smt_model (solver:Solver.Sat_solver.t) (hyps:_ Vec.t) (m:Model.t) : un
|
||||||
let is_true = S.Atom.is_true a in
|
let is_true = S.Atom.is_true a in
|
||||||
let is_false = S.Atom.is_true (S.Atom.neg a) in
|
let is_false = S.Atom.is_true (S.Atom.neg a) in
|
||||||
let sat_value = if is_true then Some true else if is_false then Some false else None in
|
let sat_value = if is_true then Some true else if is_false then Some false else None in
|
||||||
begin match Lit.as_atom lit with
|
let t, sign = Lit.as_atom lit in
|
||||||
| None -> assert false
|
begin match Model.eval m t with
|
||||||
| Some (t, sign) ->
|
|
||||||
match Model.eval m t with
|
|
||||||
| Some (V_bool b) ->
|
| Some (V_bool b) ->
|
||||||
let b = if sign then b else not b in
|
let b = if sign then b else not b in
|
||||||
if (is_true || is_false) && ((b && is_false) || (not b && is_true)) then (
|
if (is_true || is_false) && ((b && is_false) || (not b && is_true)) then (
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue