mirror of
https://github.com/c-cube/sidekick.git
synced 2026-01-28 12:24:50 -05:00
refactor(cc): internal refactorings
This commit is contained in:
parent
b8445d0ca3
commit
1d212350ef
1 changed files with 53 additions and 48 deletions
|
|
@ -3,15 +3,17 @@ module Vec = Sidekick_sat.Vec
|
||||||
module Log = Sidekick_sat.Log
|
module Log = Sidekick_sat.Log
|
||||||
open Solver_types
|
open Solver_types
|
||||||
|
|
||||||
type node = Equiv_class.t
|
module N = Equiv_class
|
||||||
type repr = Equiv_class.t
|
|
||||||
|
type node = N.t
|
||||||
|
type repr = N.t
|
||||||
type conflict = Theory.conflict
|
type conflict = Theory.conflict
|
||||||
|
|
||||||
(** A signature is a shallow term shape where immediate subterms
|
(** A signature is a shallow term shape where immediate subterms
|
||||||
are representative *)
|
are representative *)
|
||||||
module Signature = struct
|
module Signature = struct
|
||||||
type t = node Term.view
|
type t = node Term.view
|
||||||
include Term_cell.Make_eq(Equiv_class)
|
include Term_cell.Make_eq(N)
|
||||||
end
|
end
|
||||||
|
|
||||||
module Sig_tbl = CCHashtbl.Make(Signature)
|
module Sig_tbl = CCHashtbl.Make(Signature)
|
||||||
|
|
@ -51,6 +53,7 @@ type t = {
|
||||||
have the same signature *)
|
have the same signature *)
|
||||||
tasks: task Vec.t;
|
tasks: task Vec.t;
|
||||||
(* tasks to perform *)
|
(* tasks to perform *)
|
||||||
|
on_backtrack:(unit->unit)->unit;
|
||||||
mutable ps_lits: Lit.Set.t;
|
mutable ps_lits: Lit.Set.t;
|
||||||
(* proof state *)
|
(* proof state *)
|
||||||
ps_queue: (node*node) Vec.t;
|
ps_queue: (node*node) Vec.t;
|
||||||
|
|
@ -64,10 +67,7 @@ type t = {
|
||||||
several times.
|
several times.
|
||||||
See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *)
|
See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *)
|
||||||
|
|
||||||
let[@inline] on_backtrack cc f : unit =
|
let[@inline] on_backtrack cc f : unit = cc.on_backtrack f
|
||||||
let (module A) = cc.acts in
|
|
||||||
A.on_backtrack f
|
|
||||||
|
|
||||||
let[@inline] is_root_ (n:node) : bool = n.n_root == n
|
let[@inline] is_root_ (n:node) : bool = n.n_root == n
|
||||||
|
|
||||||
let[@inline] size_ (r:repr) =
|
let[@inline] size_ (r:repr) =
|
||||||
|
|
@ -112,7 +112,7 @@ let[@inline] find st (n:node) : repr =
|
||||||
let[@inline] find_tn cc (t:term) : repr = get_ cc t |> find cc
|
let[@inline] find_tn cc (t:term) : repr = get_ cc t |> find cc
|
||||||
|
|
||||||
let[@inline] same_class cc (n1:node)(n2:node): bool =
|
let[@inline] same_class cc (n1:node)(n2:node): bool =
|
||||||
Equiv_class.equal (find cc n1) (find cc n2)
|
N.equal (find cc n1) (find cc n2)
|
||||||
|
|
||||||
(* compute signature *)
|
(* compute signature *)
|
||||||
let signature cc (t:term): node Term.view option =
|
let signature cc (t:term): node Term.view option =
|
||||||
|
|
@ -120,7 +120,7 @@ let signature cc (t:term): node Term.view option =
|
||||||
begin match Term.view t with
|
begin match Term.view t with
|
||||||
| App_cst (_, a) when IArray.is_empty a -> None
|
| App_cst (_, a) when IArray.is_empty a -> None
|
||||||
| App_cst (c, _) when not @@ Cst.do_cc c -> None (* no CC *)
|
| App_cst (c, _) when not @@ Cst.do_cc c -> None (* no CC *)
|
||||||
| App_cst (f, a) -> App_cst (f, IArray.map find a) |> CCOpt.return (* FIXME: relevance *)
|
| App_cst (f, a) -> Some (App_cst (f, IArray.map find a)) (* FIXME: relevance? *)
|
||||||
| Bool _ | If _
|
| Bool _ | If _
|
||||||
-> None (* no congruence for these *)
|
-> None (* no congruence for these *)
|
||||||
end
|
end
|
||||||
|
|
@ -146,16 +146,16 @@ let add_signature cc (t:term) (r:node): unit =
|
||||||
end
|
end
|
||||||
|
|
||||||
let push_pending cc t : unit =
|
let push_pending cc t : unit =
|
||||||
if not @@ Equiv_class.get_field Equiv_class.field_is_pending t then (
|
if not @@ N.get_field N.field_is_pending t then (
|
||||||
Log.debugf 5 (fun k->k "(@[<hv1>cc.push_pending@ %a@])" Equiv_class.pp t);
|
Log.debugf 5 (fun k->k "(@[<hv1>cc.push_pending@ %a@])" N.pp t);
|
||||||
Equiv_class.set_field Equiv_class.field_is_pending true t;
|
N.set_field N.field_is_pending true t;
|
||||||
Vec.push cc.tasks (T_pending t)
|
Vec.push cc.tasks (T_pending t)
|
||||||
)
|
)
|
||||||
|
|
||||||
let push_combine cc t u e : unit =
|
let push_combine cc t u e : unit =
|
||||||
Log.debugf 5
|
Log.debugf 5
|
||||||
(fun k->k "(@[<hv1>cc.push_combine@ :t1 %a@ :t2 %a@ :expl %a@])"
|
(fun k->k "(@[<hv1>cc.push_combine@ :t1 %a@ :t2 %a@ :expl %a@])"
|
||||||
Equiv_class.pp t Equiv_class.pp u Explanation.pp e);
|
N.pp t N.pp u Explanation.pp e);
|
||||||
Vec.push cc.tasks @@ T_merge (t,u,e)
|
Vec.push cc.tasks @@ T_merge (t,u,e)
|
||||||
|
|
||||||
(* re-root the explanation tree of the equivalence class of [n]
|
(* re-root the explanation tree of the equivalence class of [n]
|
||||||
|
|
@ -177,7 +177,7 @@ let raise_conflict (cc:t) (e:conflict): _ =
|
||||||
(* clear tasks queue *)
|
(* clear tasks queue *)
|
||||||
Vec.iter
|
Vec.iter
|
||||||
(function
|
(function
|
||||||
| T_pending n -> Equiv_class.set_field Equiv_class.field_is_pending false n
|
| T_pending n -> N.set_field N.field_is_pending false n
|
||||||
| T_merge _ -> ())
|
| T_merge _ -> ())
|
||||||
cc.tasks;
|
cc.tasks;
|
||||||
Vec.clear cc.tasks;
|
Vec.clear cc.tasks;
|
||||||
|
|
@ -259,8 +259,8 @@ let explain_loop (cc : t) : Lit.Set.t =
|
||||||
while not (Vec.is_empty cc.ps_queue) do
|
while not (Vec.is_empty cc.ps_queue) do
|
||||||
let a, b = Vec.pop_last cc.ps_queue in
|
let a, b = Vec.pop_last cc.ps_queue in
|
||||||
Log.debugf 5
|
Log.debugf 5
|
||||||
(fun k->k "(@[cc.explain_loop at@ %a@ %a@])" Equiv_class.pp a Equiv_class.pp b);
|
(fun k->k "(@[cc.explain_loop at@ %a@ %a@])" N.pp a N.pp b);
|
||||||
assert (Equiv_class.equal (find cc a) (find cc b));
|
assert (N.equal (find cc a) (find cc b));
|
||||||
let c = find_common_ancestor a b in
|
let c = find_common_ancestor a b in
|
||||||
explain_along_path cc a c;
|
explain_along_path cc a c;
|
||||||
explain_along_path cc b c;
|
explain_along_path cc b c;
|
||||||
|
|
@ -296,6 +296,17 @@ let add_tag_n cc (n:node) (tag:int) (expl:explanation) : unit =
|
||||||
n.n_tags <- Util.Int_map.add tag (n,expl) n.n_tags;
|
n.n_tags <- Util.Int_map.add tag (n,expl) n.n_tags;
|
||||||
)
|
)
|
||||||
|
|
||||||
|
let relevant_subterms (t:Term.t) : Term.t Sequence.t =
|
||||||
|
fun yield ->
|
||||||
|
match t.term_view with
|
||||||
|
| App_cst (c, a) when Cst.do_cc c -> IArray.iter yield a
|
||||||
|
| Bool _ | App_cst _ -> ()
|
||||||
|
| If (a,b,c) ->
|
||||||
|
(* TODO: relevancy? only [a] needs be decided for now *)
|
||||||
|
yield a;
|
||||||
|
yield b;
|
||||||
|
yield c
|
||||||
|
|
||||||
(* main CC algo: add terms from [pending] to the signature table,
|
(* main CC algo: add terms from [pending] to the signature table,
|
||||||
check for collisions *)
|
check for collisions *)
|
||||||
let rec update_tasks (cc:t): unit =
|
let rec update_tasks (cc:t): unit =
|
||||||
|
|
@ -309,7 +320,7 @@ let rec update_tasks (cc:t): unit =
|
||||||
done
|
done
|
||||||
|
|
||||||
and task_pending_ cc n =
|
and task_pending_ cc n =
|
||||||
Equiv_class.set_field Equiv_class.field_is_pending false n;
|
N.set_field N.field_is_pending false n;
|
||||||
(* check if some parent collided *)
|
(* check if some parent collided *)
|
||||||
begin match find_by_signature cc n.n_term with
|
begin match find_by_signature cc n.n_term with
|
||||||
| None ->
|
| None ->
|
||||||
|
|
@ -343,7 +354,7 @@ and task_pending_ cc n =
|
||||||
and task_merge_ cc a b e_ab : unit =
|
and task_merge_ cc a b e_ab : unit =
|
||||||
let ra = find cc a in
|
let ra = find cc a in
|
||||||
let rb = find cc b in
|
let rb = find cc b in
|
||||||
if not (Equiv_class.equal ra rb) then (
|
if not (N.equal ra rb) then (
|
||||||
assert (is_root_ ra);
|
assert (is_root_ ra);
|
||||||
assert (is_root_ rb);
|
assert (is_root_ rb);
|
||||||
(* We will merge [r_from] into [r_into].
|
(* We will merge [r_from] into [r_into].
|
||||||
|
|
@ -357,8 +368,8 @@ and task_merge_ cc a b e_ab : unit =
|
||||||
Log.debugf 5
|
Log.debugf 5
|
||||||
(fun k->k "(@[<hv>cc.merge.distinct_conflict@ :tag %d@ \
|
(fun k->k "(@[<hv>cc.merge.distinct_conflict@ :tag %d@ \
|
||||||
@[:r1 %a@ :e1 %a@]@ @[:r2 %a@ :e2 %a@]@ :e_ab %a@])"
|
@[:r1 %a@ :e1 %a@]@ @[:r2 %a@ :e2 %a@]@ :e_ab %a@])"
|
||||||
_i Equiv_class.pp n1 Explanation.pp e1
|
_i N.pp n1 Explanation.pp e1
|
||||||
Equiv_class.pp n2 Explanation.pp e2 Explanation.pp e_ab);
|
N.pp n2 Explanation.pp e2 Explanation.pp e_ab);
|
||||||
let lits = explain_unfold cc e1 in
|
let lits = explain_unfold cc e1 in
|
||||||
let lits = explain_unfold ~init:lits cc e2 in
|
let lits = explain_unfold ~init:lits cc e2 in
|
||||||
let lits = explain_unfold ~init:lits cc e_ab in
|
let lits = explain_unfold ~init:lits cc e_ab in
|
||||||
|
|
@ -373,7 +384,7 @@ and task_merge_ cc a b e_ab : unit =
|
||||||
(fun parent -> push_pending cc parent)
|
(fun parent -> push_pending cc parent)
|
||||||
end;
|
end;
|
||||||
(* perform [union ra rb] *)
|
(* perform [union ra rb] *)
|
||||||
Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" Equiv_class.pp r_from Equiv_class.pp r_into);
|
Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into);
|
||||||
begin
|
begin
|
||||||
let r_into_old_parents = r_into.n_parents in
|
let r_into_old_parents = r_into.n_parents in
|
||||||
let r_into_old_tags = r_into.n_tags in
|
let r_into_old_tags = r_into.n_tags in
|
||||||
|
|
@ -412,7 +423,7 @@ and notify_merge cc (ra:repr) ~into:(rb:repr) (e:explanation): unit =
|
||||||
and add_new_term_ cc (t:term) : node =
|
and add_new_term_ cc (t:term) : node =
|
||||||
assert (not @@ mem cc t);
|
assert (not @@ mem cc t);
|
||||||
Log.debugf 15 (fun k->k "(@[cc.add-term@ %a@])" Term.pp t);
|
Log.debugf 15 (fun k->k "(@[cc.add-term@ %a@])" Term.pp t);
|
||||||
let n = Equiv_class.make t in
|
let n = N.make t in
|
||||||
(* how to add a subterm *)
|
(* how to add a subterm *)
|
||||||
let add_to_parents_of_sub_node (sub:node) : unit =
|
let add_to_parents_of_sub_node (sub:node) : unit =
|
||||||
let old_parents = sub.n_parents in
|
let old_parents = sub.n_parents in
|
||||||
|
|
@ -426,15 +437,7 @@ and add_new_term_ cc (t:term) : node =
|
||||||
add_to_parents_of_sub_node n_u
|
add_to_parents_of_sub_node n_u
|
||||||
in
|
in
|
||||||
(* register sub-terms, add [t] to their parent list *)
|
(* register sub-terms, add [t] to their parent list *)
|
||||||
begin match t.term_view with
|
relevant_subterms t add_sub_t;
|
||||||
| App_cst (c, a) when Cst.do_cc c -> IArray.iter add_sub_t a
|
|
||||||
| Bool _ | App_cst _ -> ()
|
|
||||||
| If (a,b,c) ->
|
|
||||||
(* TODO: relevancy? only [a] needs be decided for now *)
|
|
||||||
add_sub_t a;
|
|
||||||
add_sub_t b;
|
|
||||||
add_sub_t c
|
|
||||||
end;
|
|
||||||
(* remove term when we backtrack *)
|
(* remove term when we backtrack *)
|
||||||
on_backtrack cc
|
on_backtrack cc
|
||||||
(fun () ->
|
(fun () ->
|
||||||
|
|
@ -493,7 +496,7 @@ let assert_distinct cc (l:term list) ~neq (lit:Lit.t) : unit =
|
||||||
let l = List.map (fun t -> t, add cc t |> find cc) l in
|
let l = List.map (fun t -> t, add cc t |> find cc) l in
|
||||||
let coll =
|
let coll =
|
||||||
Sequence.diagonal_l l
|
Sequence.diagonal_l l
|
||||||
|> Sequence.find_pred (fun ((_,n1),(_,n2)) -> Equiv_class.equal n1 n2)
|
|> Sequence.find_pred (fun ((_,n1),(_,n2)) -> N.equal n1 n2)
|
||||||
in
|
in
|
||||||
begin match coll with
|
begin match coll with
|
||||||
| Some ((t1,_n1),(t2,_n2)) ->
|
| Some ((t1,_n1),(t2,_n2)) ->
|
||||||
|
|
@ -508,17 +511,19 @@ let assert_distinct cc (l:term list) ~neq (lit:Lit.t) : unit =
|
||||||
end
|
end
|
||||||
|
|
||||||
let create ?(size=2048) ~actions (tst:Term.state) : t =
|
let create ?(size=2048) ~actions (tst:Term.state) : t =
|
||||||
let nd = Equiv_class.dummy in
|
let nd = N.dummy in
|
||||||
|
let (module A : ACTIONS) = actions in
|
||||||
let cc = {
|
let cc = {
|
||||||
tst;
|
tst;
|
||||||
acts=actions;
|
acts=actions;
|
||||||
tbl = Term.Tbl.create size;
|
tbl = Term.Tbl.create size;
|
||||||
signatures_tbl = Sig_tbl.create size;
|
signatures_tbl = Sig_tbl.create size;
|
||||||
tasks=Vec.make_empty (T_pending Equiv_class.dummy);
|
tasks=Vec.make_empty (T_pending N.dummy);
|
||||||
ps_lits=Lit.Set.empty;
|
ps_lits=Lit.Set.empty;
|
||||||
|
on_backtrack=A.on_backtrack;
|
||||||
ps_queue=Vec.make_empty (nd,nd);
|
ps_queue=Vec.make_empty (nd,nd);
|
||||||
true_ = Equiv_class.dummy;
|
true_ = N.dummy;
|
||||||
false_ = Equiv_class.dummy;
|
false_ = N.dummy;
|
||||||
} in
|
} in
|
||||||
cc.true_ <- add cc (Term.true_ tst);
|
cc.true_ <- add cc (Term.true_ tst);
|
||||||
cc.false_ <- add cc (Term.false_ tst);
|
cc.false_ <- add cc (Term.false_ tst);
|
||||||
|
|
@ -531,7 +536,7 @@ let final_check cc : unit =
|
||||||
(* model: map each uninterpreted equiv class to some ID *)
|
(* model: map each uninterpreted equiv class to some ID *)
|
||||||
let mk_model (cc:t) (m:Model.t) : Model.t =
|
let mk_model (cc:t) (m:Model.t) : Model.t =
|
||||||
(* populate [repr -> value] table *)
|
(* populate [repr -> value] table *)
|
||||||
let t_tbl = Equiv_class.Tbl.create 32 in
|
let t_tbl = N.Tbl.create 32 in
|
||||||
(* type -> default value *)
|
(* type -> default value *)
|
||||||
let ty_tbl = Ty.Tbl.create 8 in
|
let ty_tbl = Ty.Tbl.create 8 in
|
||||||
Term.Tbl.values cc.tbl
|
Term.Tbl.values cc.tbl
|
||||||
|
|
@ -552,7 +557,7 @@ let mk_model (cc:t) (m:Model.t) : Model.t =
|
||||||
if not @@ Ty.Tbl.mem ty_tbl (Term.ty t) then (
|
if not @@ Ty.Tbl.mem ty_tbl (Term.ty t) then (
|
||||||
Ty.Tbl.add ty_tbl (Term.ty t) v; (* also give a value to this type *)
|
Ty.Tbl.add ty_tbl (Term.ty t) v; (* also give a value to this type *)
|
||||||
);
|
);
|
||||||
Equiv_class.Tbl.add t_tbl r v
|
N.Tbl.add t_tbl r v
|
||||||
));
|
));
|
||||||
(* now map every uninterpreted term to its representative's value, and
|
(* now map every uninterpreted term to its representative's value, and
|
||||||
create function tables *)
|
create function tables *)
|
||||||
|
|
@ -568,20 +573,20 @@ let mk_model (cc:t) (m:Model.t) : Model.t =
|
||||||
else if Cst.is_undefined c && IArray.length args > 0 then (
|
else if Cst.is_undefined c && IArray.length args > 0 then (
|
||||||
(* update signature of [c] *)
|
(* update signature of [c] *)
|
||||||
let ty = Term.ty t in
|
let ty = Term.ty t in
|
||||||
let v = Equiv_class.Tbl.find t_tbl r in
|
let v = N.Tbl.find t_tbl r in
|
||||||
let args =
|
let args =
|
||||||
args
|
args
|
||||||
|> IArray.map (fun t -> Equiv_class.Tbl.find t_tbl @@ find_tn cc t)
|
|> IArray.map (fun t -> N.Tbl.find t_tbl @@ find_tn cc t)
|
||||||
|> IArray.to_list
|
|> IArray.to_list
|
||||||
in
|
in
|
||||||
let ty, l = Cst.Map.get_or c funs ~default:(ty,[]) in
|
let ty, l = Cst.Map.get_or c funs ~default:(ty,[]) in
|
||||||
m, Cst.Map.add c (ty, (args,v)::l) funs
|
m, Cst.Map.add c (ty, (args,v)::l) funs
|
||||||
) else (
|
) else (
|
||||||
let v = Equiv_class.Tbl.find t_tbl r in
|
let v = N.Tbl.find t_tbl r in
|
||||||
Model.add t v m, funs
|
Model.add t v m, funs
|
||||||
)
|
)
|
||||||
| _ ->
|
| _ ->
|
||||||
let v = Equiv_class.Tbl.find t_tbl r in
|
let v = N.Tbl.find t_tbl r in
|
||||||
Model.add t v m, funs)
|
Model.add t v m, funs)
|
||||||
(m,Cst.Map.empty)
|
(m,Cst.Map.empty)
|
||||||
in
|
in
|
||||||
|
|
@ -611,12 +616,12 @@ let mk_model (cc:t) (m:Model.t) : Model.t =
|
||||||
let pp_full out (cc:t) : unit =
|
let pp_full out (cc:t) : unit =
|
||||||
let pp_n out n =
|
let pp_n out n =
|
||||||
let pp_next out n =
|
let pp_next out n =
|
||||||
if n==n.n_root then () else Fmt.fprintf out "@ :next %a" Equiv_class.pp n.n_root in
|
if n==n.n_root then () else Fmt.fprintf out "@ :next %a" N.pp n.n_root in
|
||||||
let pp_root out n =
|
let pp_root out n =
|
||||||
let u = find cc n in if n==u||n.n_root==u then () else Fmt.fprintf out "@ :root %a" Equiv_class.pp u in
|
let u = find cc n in if n==u||n.n_root==u then () else Fmt.fprintf out "@ :root %a" N.pp u in
|
||||||
Fmt.fprintf out "(@[%a%a%a@])" Term.pp n.n_term pp_next n pp_root n
|
Fmt.fprintf out "(@[%a%a%a@])" Term.pp n.n_term pp_next n pp_root n
|
||||||
and pp_sig_e out (s,n) =
|
and pp_sig_e out (s,n) =
|
||||||
Fmt.fprintf out "(@[<1>%a@ -> %a@])" Signature.pp s Equiv_class.pp n
|
Fmt.fprintf out "(@[<1>%a@ -> %a@])" Signature.pp s N.pp n
|
||||||
in
|
in
|
||||||
Fmt.fprintf out
|
Fmt.fprintf out
|
||||||
"(@[cc.state@ (@[<hv>:nodes@ %a@])@ (@[<hv>:sig@ %a@])@])"
|
"(@[cc.state@ (@[<hv>:nodes@ %a@])@ (@[<hv>:sig@ %a@])@])"
|
||||||
|
|
@ -633,7 +638,7 @@ let check_invariants_ (cc:t) =
|
||||||
Term.Tbl.iter
|
Term.Tbl.iter
|
||||||
(fun t n ->
|
(fun t n ->
|
||||||
assert (Term.equal t n.n_term);
|
assert (Term.equal t n.n_term);
|
||||||
assert (not @@ Equiv_class.get_field Equiv_class.field_is_pending n);
|
assert (not @@ N.get_field N.field_is_pending n);
|
||||||
relevant_subterms t
|
relevant_subterms t
|
||||||
(fun u -> assert (Term.Tbl.mem cc.tbl u));
|
(fun u -> assert (Term.Tbl.mem cc.tbl u));
|
||||||
(* check proper signature *)
|
(* check proper signature *)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue