(* sidekick/src/cc/Congruence_closure.ml
   2019-02-22 18:54:56 -06:00

   923 lines
   33 KiB
   OCaml *)

open CC_types
module type ARG = Congruence_closure_intf.ARG
module type S = Congruence_closure_intf.S
(** Per-node flags, stored in each node's [n_bits] field.
    Each [Bits.mk_field ()] call below creates a new field; [Bits.freeze ()]
    then closes the set of fields (see CCBitField), so the calls must stay
    in this order and no field may be added after the freeze. *)
module Bits = CCBitField.Make()
(** true iff the node is in the [cc.pending] queue *)
let field_is_pending = Bits.mk_field()
(** General purpose *)
let field_usr1 = Bits.mk_field()
(** General purpose *)
let field_usr2 = Bits.mk_field()
let () = Bits.freeze()
module Make(A: ARG) = struct
  type term = A.Term.t
  type term_state = A.Term.state
  type lit = A.Lit.t
  type fun_ = A.Fun.t
  type proof = A.Proof.t
  type model = A.Model.t

  (** Actions available to the theory *)
  type sat_actions = (Msat.void, lit, Msat.void, proof) Msat.acts

  module T = A.Term
  module Fun = A.Fun

  (** A node of the congruence closure.
      An equivalence class is represented by its "root" element,
      the representative. *)
  type node = {
    n_term: term;
    mutable n_sig0: signature option; (* initial signature *)
    mutable n_bits: Bits.t; (* bitfield for various properties *)
    mutable n_parents: node Bag.t; (* parent terms of this node *)
    mutable n_root: node; (* representative of congruence class (itself if a representative) *)
    mutable n_next: node; (* pointer to next element of congruence class (circular list) *)
    mutable n_size: int; (* size of the class; only kept up to date on representatives *)
    mutable n_as_lit: lit option; (* TODO: put into payload? and only in root? *)
    mutable n_expl: explanation_forest_link; (* the rooted forest for explanations *)
    (* TODO: make a micro theory and move this inside *)
    mutable n_tags: (node * explanation) Util.Int_map.t;
    (* distinct tags, i.e. the set of (distinct t1…tn) terms this belongs to *)
  }

  and signature = (fun_, node, node list) view

  (* a link in the explanation forest: [FL_none] for a root of the forest,
     otherwise the next node towards the root and why they are equal *)
  and explanation_forest_link =
    | FL_none
    | FL_some of {
        next: node;
        expl: explanation;
      }

  (* atomic explanation in the congruence closure *)
  and explanation =
    | E_reduction (* by pure reduction, tautologically equal *)
    | E_merge of node * node
    | E_merges of (node * node) list (* caused by these merges *)
    | E_congruence of node * node (* caused by normal congruence *)
    | E_lit of lit (* because of this literal *)
    | E_lits of lit list (* because of this (true) conjunction *)
    (* TODO: congruence case (cheaper than merges) *)

  type repr = node
  type conflict = lit list

  module N = struct
    type t = node

    let[@inline] equal (n1:t) n2 = T.equal n1.n_term n2.n_term
    let[@inline] hash n = T.hash n.n_term
    let[@inline] term n = n.n_term
    let[@inline] pp out n = T.pp out n.n_term
    let[@inline] as_lit n = n.n_as_lit

    (* fresh node for [t]: a singleton class (its own root and next),
       with no parents, no signature, no explanation link and no tags *)
    let make (t:term) : t =
      let rec n = {
        n_term=t;
        n_sig0= None;
        n_bits=Bits.empty;
        n_parents=Bag.empty;
        n_as_lit=None; (* TODO: provide a method to do it *)
        n_root=n;
        n_expl=FL_none;
        n_next=n;
        n_size=1;
        n_tags=Util.Int_map.empty;
      } in
      n

    let[@inline] is_root (n:node) : bool = n.n_root == n

    (* traverse the equivalence class of [n] by following the circular
       [n_next] list, starting (and stopping) at [n] *)
    let iter_class_ (n:node) : node Sequence.t =
      fun yield ->
        let rec aux u =
          yield u;
          if u.n_next != n then aux u.n_next
        in
        aux n

    (* same as [iter_class_], but [n] must be a representative *)
    let iter_class n =
      assert (is_root n);
      iter_class_ n

    let[@inline] iter_parents (n:node) : node Sequence.t =
      assert (is_root n);
      Bag.to_seq n.n_parents

    let[@inline] get_field f t = Bits.get f t.n_bits
    let[@inline] set_field f b t = t.n_bits <- Bits.set f b t.n_bits
  end

  module N_tbl = CCHashtbl.Make(N)

  module Expl = struct
    type t = explanation

    let pp out (e:explanation) = match e with
      | E_reduction -> Fmt.string out "reduction"
      | E_lit lit -> A.Lit.pp out lit
      | E_congruence (n1,n2) -> Fmt.fprintf out "(@[congruence@ %a@ %a@])" N.pp n1 N.pp n2
      | E_lits l -> CCFormat.Dump.list A.Lit.pp out l
      | E_merge (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" N.pp a N.pp b
      | E_merges l ->
        Format.fprintf out "(@[<hv1>merges@ %a@])"
          Fmt.(seq ~sep:(return "@ ") @@ within "[" "]" @@ hvbox @@
               pair ~sep:(return " ~@ ") N.pp N.pp)
          (Sequence.of_list l)

    let mk_reduction : t = E_reduction
    let[@inline] mk_congruence n1 n2 : t = E_congruence (n1,n2)
    let[@inline] mk_merge a b : t = E_merge (a,b)

    (* smart constructors: the empty list becomes [E_reduction] and a
       singleton list becomes the corresponding atomic explanation *)
    let[@inline] mk_merges = function
      | [] -> mk_reduction
      | [(a,b)] -> mk_merge a b
      | l -> E_merges l
    let[@inline] mk_lit l : t = E_lit l
    let[@inline] mk_lits = function
      | [] -> mk_reduction
      | [x] -> mk_lit x
      | l -> E_lits l
  end

  (** A signature is a shallow term shape where immediate subterms
      are representative *)
  module Signature = struct
    type t = signature

    let equal (s1:t) s2 : bool =
      match s1, s2 with
      | Bool b1, Bool b2 -> b1=b2
      | App_fun (f1,[]), App_fun (f2,[]) -> Fun.equal f1 f2
      | App_fun (f1,l1), App_fun (f2,l2) ->
        Fun.equal f1 f2 && CCList.equal N.equal l1 l2
      | App_ho (f1,l1), App_ho (f2,l2) ->
        N.equal f1 f2 && CCList.equal N.equal l1 l2
      | If (a1,b1,c1), If (a2,b2,c2) ->
        N.equal a1 a2 && N.equal b1 b2 && N.equal c1 c2
      | Eq (a1,b1), Eq (a2,b2) ->
        N.equal a1 a2 && N.equal b1 b2
      | Opaque u1, Opaque u2 -> N.equal u1 u2
      | Bool _, _ | App_fun _, _ | App_ho _, _ | If _, _
      | Eq _, _ | Opaque _, _
        -> false

    (* the distinct constant mixed into each case separates the shapes *)
    let hash (s:t) : int =
      let module H = CCHash in
      match s with
      | Bool b -> H.combine2 10 (H.bool b)
      | App_fun (f, l) -> H.combine3 20 (Fun.hash f) (H.list N.hash l)
      | App_ho (f, l) -> H.combine3 30 (N.hash f) (H.list N.hash l)
      | Eq (a,b) -> H.combine3 40 (N.hash a) (N.hash b)
      | Opaque u -> H.combine2 50 (N.hash u)
      | If (a,b,c) -> H.combine4 60 (N.hash a)(N.hash b)(N.hash c)

    let pp out = function
      | Bool b -> Fmt.bool out b
      | App_fun (f, []) -> Fun.pp out f
      | App_fun (f, l) -> Fmt.fprintf out "(@[%a@ %a@])" Fun.pp f (Util.pp_list N.pp) l
      | App_ho (f, []) -> N.pp out f
      | App_ho (f, l) -> Fmt.fprintf out "(@[%a@ %a@])" N.pp f (Util.pp_list N.pp) l
      | Opaque t -> N.pp out t
      | Eq (a,b) -> Fmt.fprintf out "(@[=@ %a@ %a@])" N.pp a N.pp b
      | If (a,b,c) -> Fmt.fprintf out "(@[ite@ %a@ %a@ %a@])" N.pp a N.pp b N.pp c
  end

  module Sig_tbl = CCHashtbl.Make(Signature)
  module T_tbl = CCHashtbl.Make(T)

  (* tasks queued in [cc.combine] *)
  type combine_task =
    | CT_merge of node * node * explanation
    | CT_distinct of node list * int * explanation

  type t = {
    tst: term_state;
    tbl: node T_tbl.t;
    (* internalization [term -> node] *)
    signatures_tbl : node Sig_tbl.t;
    (* map a signature to the corresponding node in some equivalence class.
       A signature is a [term_cell] in which every immediate subterm
       that participates in the congruence/evaluation relation
       is normalized (i.e. is its own representative).
       The critical property is that all members of an equivalence class
       that have the same shape (including head symbol)
       have the same signature *)
    pending: node Vec.t;
    (* nodes whose signature must be (re)checked, see [task_pending_] *)
    combine: combine_task Vec.t;
    (* merges and distinct constraints still to process *)
    undo: (unit -> unit) Backtrack_stack.t;
    (* closures run when levels are popped *)
    on_merge: (repr -> repr -> explanation -> unit) option;
    mutable ps_lits: lit list; (* TODO: thread it around instead? *)
    (* proof state *)
    ps_queue: (node*node) Vec.t;
    (* pairs to explain *)
    true_ : node lazy_t;
    false_ : node lazy_t;
  }

  (* TODO: an additional union-find to keep track, for each term,
     of the terms they are known to be equal to, according
     to the current explanation. That allows not to prove some equality
     several times.
     See Fast congruence closure and extensions, Nieuwenhuis & Oliveras,
     page 14 *)

  let[@inline] size_ (r:repr) = r.n_size
  let[@inline] true_ cc = Lazy.force cc.true_
  let[@inline] false_ cc = Lazy.force cc.false_

  (* register [f] to be run when the current level is popped *)
  let[@inline] on_backtrack cc f : unit =
    Backtrack_stack.push_if_nonzero_level cc.undo f

  (* check if [t] is in the congruence closure.
     Invariant: [in_cc t ∧ do_cc t => forall u subterm t, in_cc u] *)
  let[@inline] mem (cc:t) (t:term): bool = T_tbl.mem cc.tbl t

  (* find representative, recursively, with path compression *)
  let[@unroll 2] rec find_rec (n:node) : repr =
    if n==n.n_root then (
      n
    ) else (
      let root = find_rec n.n_root in
      if root != n.n_root then (
        n.n_root <- root; (* path compression *)
      );
      root
    )

  (* non-recursive, inlinable function for [find] *)
  let[@inline] find_ (n:node) : repr =
    if n == n.n_root then n else find_rec n.n_root

  let[@inline] same_class (n1:node)(n2:node): bool =
    N.equal (find_ n1) (find_ n2)

  let[@inline] find _ n = find_ n

  (* print full state *)
  let pp_full out (cc:t) : unit =
    let pp_next out n =
      Fmt.fprintf out "@ :next %a" N.pp n.n_next in
    let pp_root out n =
      if N.is_root n then Fmt.string out " :is-root" else Fmt.fprintf out "@ :root %a" N.pp n.n_root in
    let pp_expl out n = match n.n_expl with
      | FL_none -> ()
      | FL_some e ->
        Fmt.fprintf out " (@[:forest %a :expl %a@])" N.pp e.next Expl.pp e.expl
    in
    let pp_n out n =
      Fmt.fprintf out "(@[%a%a%a%a@])" T.pp n.n_term pp_root n pp_next n pp_expl n
    and pp_sig_e out (s,n) =
      Fmt.fprintf out "(@[<1>%a@ ~~> %a%a@])" Signature.pp s N.pp n pp_root n
    in
    Fmt.fprintf out
      "(@[@{<yellow>cc.state@}@ (@[<hv>:nodes@ %a@])@ (@[<hv>:sig-tbl@ %a@])@])"
      (Util.pp_seq ~sep:" " pp_n) (T_tbl.values cc.tbl)
      (Util.pp_seq ~sep:" " pp_sig_e) (Sig_tbl.to_seq cc.signatures_tbl)

  (* compute up-to-date signature: replace every sub-node by its
     current representative *)
  let update_sig (s:signature) : Signature.t =
    CC_types.map_view s
      ~f_f:(fun x->x)
      ~f_t:find_
      ~f_ts:(List.map find_)

  (* find whether the given (parent) term corresponds to some signature
     in [cc.signatures_tbl] *)
  let[@inline] find_signature cc (s:signature) : repr option =
    Sig_tbl.get cc.signatures_tbl s

  (* map [s] to [n] in the signature table, unless [s] is already mapped
     (in which case the existing node must be in the class of [n]).
     The insertion is undone on backtrack. *)
  let add_signature cc (s:signature) (n:node) : unit =
    (* add, but only if not present already *)
    match Sig_tbl.find cc.signatures_tbl s with
    | exception Not_found ->
      Log.debugf 15
        (fun k->k "(@[cc.add-sig@ %a@ ~~> %a@])" Signature.pp s N.pp n);
      on_backtrack cc (fun () -> Sig_tbl.remove cc.signatures_tbl s);
      Sig_tbl.add cc.signatures_tbl s n;
    | r' ->
      assert (same_class n r');
      ()

  (* queue [t] for a signature check, unless it is already queued *)
  let push_pending cc t : unit =
    if not @@ N.get_field field_is_pending t then (
      Log.debugf 5 (fun k->k "(@[<hv1>cc.push_pending@ %a@])" N.pp t);
      N.set_field field_is_pending true t;
      Vec.push cc.pending t
    )

  (* queue the merge of [t] and [u], justified by [e] *)
  let push_combine cc t u e : unit =
    Log.debugf 5
      (fun k->k "(@[<hv1>cc.push_combine@ %a ~@ %a@ :expl %a@])"
          N.pp t N.pp u Expl.pp e);
    Vec.push cc.combine @@ CT_merge (t,u,e)

  (* re-root the explanation tree of the equivalence class of [n]
     so that it points to [n], by reversing every link on the path
     from [n] to the current root of the explanation forest.
     postcondition: [n.n_expl = FL_none].
     The walk is tail-recursive, so its stack usage does not grow with
     the length of the path. *)
  let reroot_expl (_cc:t) (n:node): unit =
    (* [flip prev e cur]: [cur] was the parent of [prev] (link labelled [e]);
       make [cur] point back to [prev] with the same label, then continue
       with the former parent of [cur] *)
    let rec flip (prev:node) (e:explanation) (cur:node) : unit =
      let old = cur.n_expl in
      cur.n_expl <- FL_some {next=prev; expl=e};
      match old with
      | FL_none -> () (* [cur] was the old root *)
      | FL_some {next=up; expl=e_up} -> flip cur e_up up
    in
    match n.n_expl with
    | FL_none -> () (* already root *)
    | FL_some {next=u; expl=e_n_u} ->
      n.n_expl <- FL_none;
      flip n e_n_u u

  (* clear the task queues, then hand the conflict [e] to the SAT solver
     as the clause made of the negations of its literals.
     Its result type is unconstrained: [acts_raise_conflict] is not
     expected to return normally. *)
  let raise_conflict (cc:t) (acts:sat_actions) (e:conflict): _ =
    (* clear tasks queue *)
    Vec.iter (N.set_field field_is_pending false) cc.pending;
    Vec.clear cc.pending;
    Vec.clear cc.combine;
    let c = List.rev_map A.Lit.neg e in
    acts.Msat.acts_raise_conflict c A.Proof.default

  (* representatives of all the classes of internalized terms *)
  let[@inline] all_classes cc : repr Sequence.t =
    T_tbl.values cc.tbl
    |> Sequence.filter N.is_root

  (* TODO: use markers and lockstep iteration instead *)
  (* distance from [n] to its root in the proof forest
     (tail-recursive, so long paths do not grow the stack) *)
  let[@inline] distance_to_root (n:node): int =
    let rec aux acc (n:node) : int = match n.n_expl with
      | FL_none -> acc
      | FL_some {next=t'; _} -> aux (acc+1) t'
    in
    aux 0 n

  (* TODO: bool flag on nodes + stepwise progress + cleanup *)
  (* find the closest common ancestor of [a] and [b] in the proof forest.
     [a] and [b] must be in the same tree of the forest. *)
  let find_common_ancestor (a:node) (b:node) : node =
    let d_a = distance_to_root a in
    let d_b = distance_to_root b in
    (* drop [n] nodes in the path from [t] to its root *)
    let rec drop_ n t =
      if n=0 then t
      else match t.n_expl with
        | FL_none -> assert false
        | FL_some {next=t'; _} -> drop_ (n-1) t'
    in
    (* reduce to the problem where [a] and [b] have the same distance to root *)
    let a, b =
      if d_a > d_b then drop_ (d_a-d_b) a, b
      else if d_a < d_b then a, drop_ (d_b-d_a) b
      else a, b
    in
    (* traverse stepwise until a==b *)
    let rec aux_same_dist a b =
      if a==b then a
      else match a.n_expl, b.n_expl with
        | FL_none, _ | _, FL_none -> assert false
        | FL_some {next=a'; _}, FL_some {next=b'; _} -> aux_same_dist a' b'
    in
    aux_same_dist a b

  (* record that [a = b] still has to be explained *)
  let[@inline] ps_add_obligation (cc:t) a b = Vec.push cc.ps_queue (a,b)

  (* record [l] as part of the explanation being built *)
  let[@inline] ps_add_lit ps l = ps.ps_lits <- l :: ps.ps_lits

  (* TODO: remove *)
  (* reset the proof state *)
  let ps_clear (cc:t) =
    cc.ps_lits <- [];
    Vec.clear cc.ps_queue;
    ()

  (* TODO: turn this into a fold? *)
  (* decompose explanation [e] into elementary facts: literals are added
     to [cc.ps_lits], and equalities that still need an explanation are
     pushed onto [cc.ps_queue] *)
  let decompose_explain cc (e:explanation) : unit =
    Log.debugf 5 (fun k->k "(@[cc.decompose_expl@ %a@])" Expl.pp e);
    begin match e with
      | E_reduction -> ()
      | E_congruence (n1, n2) ->
        (* the arguments of congruent terms must be pairwise equal *)
        begin match n1.n_sig0, n2.n_sig0 with
          | Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) ->
            assert (Fun.equal f1 f2);
            assert (List.length a1 = List.length a2);
            List.iter2 (ps_add_obligation cc) a1 a2;
          | Some (App_ho (f1, a1)), Some (App_ho (f2, a2)) ->
            assert (List.length a1 = List.length a2);
            ps_add_obligation cc f1 f2;
            List.iter2 (ps_add_obligation cc) a1 a2;
          | Some (If (a1,b1,c1)), Some (If (a2,b2,c2)) ->
            ps_add_obligation cc a1 a2;
            ps_add_obligation cc b1 b2;
            ps_add_obligation cc c1 c2;
          | _ ->
            (* congruences are only detected through the signature table,
               which never contains [Eq] signatures (see [task_pending_]);
               other nodes have no signature *)
            assert false
        end
      | E_lit lit -> ps_add_lit cc lit
      | E_lits l -> List.iter (ps_add_lit cc) l
      | E_merge (a,b) -> ps_add_obligation cc a b
      | E_merges l ->
        (* need to explain each merge in [l] *)
        List.iter (fun (t,u) -> ps_add_obligation cc t u) l
    end

  (* explain why [a = parent_a], where [a -> ... -> parent_a] in the
     proof forest *)
  let explain_along_path ps (a:node) (parent_a:node) : unit =
    let rec aux n =
      if n != parent_a then (
        match n.n_expl with
        | FL_none -> assert false
        | FL_some {next=next_n; expl=expl} ->
          decompose_explain ps expl;
          (* now prove [next_n = parent_a] *)
          aux next_n
      )
    in aux a

  (* find explanation: discharge every obligation in [cc.ps_queue]
     (which may push new ones) and return the collected literals *)
  let explain_loop (cc : t) : lit list =
    while not (Vec.is_empty cc.ps_queue) do
      let a, b = Vec.pop cc.ps_queue in
      Log.debugf 5
        (fun k->k "(@[cc.explain_loop.at@ %a@ =?= %a@])" N.pp a N.pp b);
      assert (N.equal (find_ a) (find_ b));
      let c = find_common_ancestor a b in
      explain_along_path cc a c;
      explain_along_path cc b c;
    done;
    cc.ps_lits

  (* TODO: do not use ps_lits anymore *)
  (* literals explaining [n1 = n2], prepended to [init] *)
  let explain_eq_n ?(init=[]) cc (n1:node) (n2:node) : lit list =
    ps_clear cc;
    cc.ps_lits <- init;
    ps_add_obligation cc n1 n2;
    explain_loop cc

  (* literals explaining [e], prepended to [init] *)
  let explain_unfold ?(init=[]) cc (e:explanation) : lit list =
    ps_clear cc;
    cc.ps_lits <- init;
    decompose_explain cc e;
    explain_loop cc

  (* add [tag] to [n], indicating that [n] is distinct from all the other
     nodes tagged with [tag]
     precond: [n] is a representative *)
  let add_tag_n cc (n:node) (tag:int) (expl:explanation) : unit =
    assert (N.is_root n);
    if not (Util.Int_map.mem tag n.n_tags) then (
      on_backtrack cc
        (fun () -> n.n_tags <- Util.Int_map.remove tag n.n_tags);
      n.n_tags <- Util.Int_map.add tag (n,expl) n.n_tags;
    )

  (* add a term *)
  let [@inline] rec add_term_rec_ cc t : node =
    try T_tbl.find cc.tbl t
    with Not_found -> add_new_term_ cc t

  (* add [t] to [cc] when not present already *)
  and add_new_term_ cc (t:term) : node =
    assert (not @@ mem cc t);
    Log.debugf 15 (fun k->k "(@[cc.add-term@ %a@])" T.pp t);
    let n = N.make t in
    (* register sub-terms, add [t] to their parent list, and return the
       corresponding initial signature *)
    let sig0 = compute_sig0 cc n in
    n.n_sig0 <- sig0;
    (* remove term when we backtrack *)
    on_backtrack cc
      (fun () ->
         Log.debugf 15 (fun k->k "(@[cc.remove-term@ %a@])" T.pp t);
         T_tbl.remove cc.tbl t);
    (* add term to the table *)
    T_tbl.add cc.tbl t n;
    if CCOpt.is_some sig0 then (
      (* [n] might be merged with other equiv classes *)
      push_pending cc n;
    );
    n

  (* compute the initial signature of the given node *)
  and compute_sig0 (self:t) (n:node) : Signature.t option =
    (* add sub-term to [cc], and register [n] to its parents *)
    let deref_sub (u:term) : node =
      let sub = add_term_rec_ self u in
      (* add [n] to [sub.root]'s parent list *)
      begin
        let sub = find_ sub in
        let old_parents = sub.n_parents in
        on_backtrack self (fun () -> sub.n_parents <- old_parents);
        sub.n_parents <- Bag.cons n sub.n_parents;
      end;
      sub
    in
    let[@inline] return x = Some x in
    match T.cc_view n.n_term with
    | Bool _ | Opaque _ -> None
    | Eq (a,b) ->
      let a = deref_sub a in
      let b = deref_sub b in
      return @@ Eq (a,b)
    | App_fun (f, args) ->
      let args = args |> Sequence.map deref_sub |> Sequence.to_list in
      if args<>[] then (
        return @@ App_fun (f, args)
      ) else None
    | App_ho (f, args) ->
      let args = args |> Sequence.map deref_sub |> Sequence.to_list in
      return @@ App_ho (deref_sub f, args)
    | If (a,b,c) ->
      return @@ If (deref_sub a, deref_sub b, deref_sub c)

  let[@inline] add_term cc t : node = add_term_rec_ cc t
  let[@inline] add_term' cc t : unit = ignore (add_term_rec_ cc t : node)

  (* remember that [n] is the term of [lit], unless [n] already has a
     literal; undone on backtrack *)
  let set_as_lit cc (n:node) (lit:lit) : unit =
    match n.n_as_lit with
    | Some _ -> ()
    | None ->
      Log.debugf 15 (fun k->k "(@[cc.set-as-lit@ %a@ %a@])" N.pp n A.Lit.pp lit);
      on_backtrack cc (fun () -> n.n_as_lit <- None);
      n.n_as_lit <- Some lit

  (* Notify the [cc.on_merge] callback, if any, that the class of [ra]
     has been merged into the class of [rb] because of [e].
     precond: [rb] is a representative *)
  let notify_merge cc (ra:repr) ~into:(rb:repr) (e:explanation): unit =
    assert (N.is_root rb);
    match cc.on_merge with
    | Some f -> f ra rb e
    | None -> ()

  (* is [n] the node of [true] or of [false]? *)
  let[@inline] n_is_bool (self:t) n : bool =
    N.equal n (true_ self) || N.equal n (false_ self)

  (* main CC algo: add terms from [pending] to the signature table,
     check for collisions, and process [combine]; processing one queue
     may refill the other, so loop until both are empty *)
  let rec update_tasks (cc:t) (acts:sat_actions) : unit =
    while not (Vec.is_empty cc.pending && Vec.is_empty cc.combine) do
      while not @@ Vec.is_empty cc.pending do
        task_pending_ cc (Vec.pop cc.pending);
      done;
      while not @@ Vec.is_empty cc.combine do
        task_combine_ cc acts (Vec.pop cc.combine);
      done;
    done

  (* recheck the signature of [n] after some of its sub-terms changed class *)
  and task_pending_ cc (n:node) : unit =
    N.set_field field_is_pending false n;
    (* check if some parent collided *)
    begin match n.n_sig0 with
      | None -> () (* no-op *)
      | Some (Eq (a,b)) ->
        (* if [a=b] is now true, merge [(a=b)] and [true] *)
        if same_class a b then (
          let expl = Expl.mk_merge a b in
          push_combine cc n (true_ cc) expl
        )
      | Some s0 ->
        (* update the signature by using [find] on each sub-node *)
        let s = update_sig s0 in
        match find_signature cc s with
        | None ->
          (* add to the signature table [sig(n) --> n] *)
          add_signature cc s n
        | Some u when n == u -> ()
        | Some u ->
          (* [n] and [u] must be applications of the same symbol to
             arguments that are pairwise equal *)
          assert (n != u);
          let expl = Expl.mk_congruence n u in
          push_combine cc n u expl
    end

  (* TODO: remove, once we have moved distinct to a theory *)
  and[@inline] task_combine_ cc acts = function
    | CT_merge (a,b,e_ab) -> task_merge_ cc acts a b e_ab
    | CT_distinct (l,tag,e) -> task_distinct_ cc acts l tag e

  (* main CC algo: merge the equivalence classes of [a] and [b], whose
     equality is justified by [e_ab].
     If the merge would equate [true] and [false], or two classes that
     carry the same distinct tag, the conflict is reported through
     [raise_conflict], which is not expected to return. *)
  and task_merge_ cc acts a b e_ab : unit =
    let ra = find_ a in
    let rb = find_ b in
    if not @@ N.equal ra rb then (
      assert (N.is_root ra);
      assert (N.is_root rb);
      (* check we are not merging [true] and [false] *)
      if (N.equal ra (true_ cc) && N.equal rb (false_ cc)) ||
         (N.equal rb (true_ cc) && N.equal ra (false_ cc)) then (
        Log.debugf 5
          (fun k->k "(@[<hv>cc.merge.true_false_conflict@ @[:r1 %a@]@ @[:r2 %a@]@ :e_ab %a@])"
              N.pp ra N.pp rb Expl.pp e_ab);
        let lits = explain_unfold cc e_ab in
        let lits = explain_eq_n ~init:lits cc a ra in
        let lits = explain_eq_n ~init:lits cc b rb in
        raise_conflict cc acts lits
      );
      (* We will merge [r_from] into [r_into].
         we try to ensure that [size r_from <= size r_into] in general, but
         always keep [true]/[false] as representative *)
      let r_from, r_into =
        if n_is_bool cc ra then rb, ra
        else if n_is_bool cc rb then ra, rb
        else if size_ ra > size_ rb then rb, ra
        else ra, rb
      in
      (* TODO: instead call micro theories, including distinct *)
      (* update set of tags the new node cannot be equal to *)
      let new_tags =
        Util.Int_map.union
          (fun _i (n1,e1) (n2,e2) ->
             (* both maps contain same tag [_i]. conflict clause:
                [e1 & e2 & e_ab] impossible *)
             Log.debugf 5
               (fun k->k "(@[<hv>cc.merge.distinct_conflict@ :tag %d@ \
                          @[:r1 %a@ :e1 %a@]@ @[:r2 %a@ :e2 %a@]@ :e_ab %a@])"
                   _i N.pp n1 Expl.pp e1
                   N.pp n2 Expl.pp e2 Expl.pp e_ab);
             let lits = explain_unfold cc e1 in
             let lits = explain_unfold ~init:lits cc e2 in
             let lits = explain_unfold ~init:lits cc e_ab in
             let lits = explain_eq_n ~init:lits cc a n1 in
             let lits = explain_eq_n ~init:lits cc b n2 in
             raise_conflict cc acts lits)
          ra.n_tags rb.n_tags
      in
      (* when merging terms with [true] or [false], possibly propagate them to SAT *)
      let merge_bool r1 t1 r2 t2 =
        if N.equal r1 (true_ cc) then (
          propagate_bools cc acts r2 t2 r1 t1 e_ab true
        ) else if N.equal r1 (false_ cc) then (
          propagate_bools cc acts r2 t2 r1 t1 e_ab false
        )
      in
      merge_bool ra a rb b;
      merge_bool rb b ra a;
      (* perform [union r_from r_into] *)
      Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into);
      begin
        (* parents might have a different signature, check for collisions *)
        N.iter_parents r_from
          (fun parent -> push_pending cc parent);
        (* for each node in the class of [r_from], make it point to [r_into] *)
        N.iter_class r_from
          (fun u ->
             assert (u.n_root == r_from);
             u.n_root <- r_into);
        (* now merge the classes *)
        let r_into_old_tags = r_into.n_tags in
        let r_into_old_next = r_into.n_next in
        let r_from_old_next = r_from.n_next in
        let r_into_old_parents = r_into.n_parents in
        let r_into_old_size = r_into.n_size in
        r_into.n_parents <- Bag.append r_into.n_parents r_from.n_parents;
        (* keep the class size up to date on the representative, so that the
           size comparison used to choose [r_from]/[r_into] is meaningful *)
        r_into.n_size <- r_into.n_size + r_from.n_size;
        (* on backtrack, unmerge classes and restore the pointers to [r_from] *)
        on_backtrack cc
          (fun () ->
             Log.debugf 15
               (fun k->k "(@[cc.undo_merge@ :from %a :into %a@])"
                   N.pp r_from N.pp r_into);
             r_into.n_next <- r_into_old_next;
             r_from.n_next <- r_from_old_next;
             r_into.n_tags <- r_into_old_tags;
             r_into.n_parents <- r_into_old_parents;
             r_into.n_size <- r_into_old_size;
             N.iter_class_ r_from (fun u -> u.n_root <- r_from);
          );
        r_into.n_tags <- new_tags;
        (* swap [into.next] and [from.next], merging the classes *)
        r_into.n_next <- r_from_old_next;
        r_from.n_next <- r_into_old_next;
      end;
      (* update explanations (a -> b), arbitrarily.
         Note that here we merge the classes by adding a bridge between [a]
         and [b], not their roots. *)
      begin
        reroot_expl cc a;
        assert (a.n_expl = FL_none);
        (* on backtracking, link may be inverted, but we delete the one
           that bridges between [a] and [b] *)
        on_backtrack cc
          (fun () -> match a.n_expl, b.n_expl with
             | FL_some e, _ when N.equal e.next b -> a.n_expl <- FL_none
             | _, FL_some e when N.equal e.next a -> b.n_expl <- FL_none
             | _ -> assert false);
        a.n_expl <- FL_some {next=b; expl=e_ab};
      end;
      (* notify listeners of the merge *)
      notify_merge cc r_from ~into:r_into e_ab;
    )

  (* assert that the nodes of [l] are pairwise distinct (justified by
     [expl]): conflict if two of them are already in the same class,
     otherwise tag every class with [tag] so that merging two of them
     later fails *)
  and task_distinct_ cc acts (l:node list) tag expl : unit =
    let l = List.map (fun n -> n, find_ n) l in
    let coll =
      Sequence.diagonal_l l
      |> Sequence.find_pred (fun ((_,r1),(_,r2)) -> N.equal r1 r2)
    in
    begin match coll with
      | Some ((n1,_r1),(n2,_r2)) ->
        (* two classes are already equal *)
        Log.debugf 5
          (fun k->k "(@[cc.distinct.conflict@ %a = %a@ :expl %a@])" N.pp n1 N.pp
              n2 Expl.pp expl);
        let lits = explain_unfold cc expl in
        raise_conflict cc acts lits
      | None ->
        (* put a tag on all equivalence classes, that will make their merge fail *)
        List.iter (fun (_,n) -> add_tag_n cc n tag expl) l
    end

  (* we are merging [r1] with [r2==Bool(sign)], so propagate each term [u1]
     in the equiv class of [r1] that is a known literal back to the SAT solver
     and which is not the one initially merged.
     We can explain the propagation with [u1 = t1 =e= t2 = r2==bool] *)
  and propagate_bools cc acts r1 t1 r2 t2 (e_12:explanation) sign : unit =
    (* explanation for [t1 =e= t2 = r2] *)
    let half_expl = lazy (
      let expl = explain_unfold cc e_12 in
      explain_eq_n ~init:expl cc r2 t2
    ) in
    (* TODO: flag per class, or-ed on merge, to indicate if the class
       contains at least one lit *)
    N.iter_class r1
      (fun u1 ->
         (* propagate if:
            - [u1] is a proper literal
            - [t2 != r2], because that can only happen
              after an explicit merge (no way to obtain that by propagation)
         *)
         match N.as_lit u1 with
         | Some lit when not (N.equal r2 t2) ->
           let lit = if sign then lit else A.Lit.neg lit in (* apply sign *)
           Log.debugf 5 (fun k->k "(@[cc.bool_propagate@ %a@])" A.Lit.pp lit);
           (* complete explanation with the [u1=t1] chunk *)
           let expl () =
             let e = explain_eq_n ~init:(Lazy.force half_expl) cc u1 t1 in
             e, A.Proof.default in
           let reason = Msat.Consequence expl in
           acts.Msat.acts_propagate lit reason
         | _ -> ())

  (* check the internal invariants of [cc] (assertion failure if broken) *)
  let check_invariants_ (cc:t) =
    Log.debug 5 "(cc.check-invariants)";
    Log.debugf 15 (fun k-> k "%a" pp_full cc);
    assert (T.equal (T.bool cc.tst true) (true_ cc).n_term);
    assert (T.equal (T.bool cc.tst false) (false_ cc).n_term);
    assert (not @@ same_class (true_ cc) (false_ cc));
    assert (Vec.is_empty cc.combine);
    assert (Vec.is_empty cc.pending);
    (* check every internalized term *)
    T_tbl.iter
      (fun t n ->
         assert (T.equal t n.n_term);
         assert (not @@ N.get_field field_is_pending n);
         assert (N.equal n.n_root n.n_next.n_root);
         (* check proper signature.
            note that some signatures in the sig table can be obsolete (they
            were not removed) but there must be a valid, up-to-date signature for
            each term *)
         begin match CCOpt.map update_sig n.n_sig0 with
           | None -> ()
           | Some s ->
             Log.debugf 15 (fun k->k "(@[cc.check-sig@ %a@ :sig %a@])" T.pp t Signature.pp s);
             (* the up-to-date signature must be registered, and map to a
                node in the same class as [n] *)
             begin match Sig_tbl.find cc.signatures_tbl s with
               | exception Not_found -> assert false
               | repr_s -> assert (same_class n repr_s)
             end
         end;
      )
      cc.tbl;
    ()

  let[@inline] check_invariants (cc:t) : unit =
    if Util._CHECK_INVARIANTS then check_invariants_ cc

  (* internalize every term of [seq] *)
  let add_seq cc seq =
    seq (fun t -> ignore @@ add_term_rec_ cc t);
    ()

  let[@inline] push_level (self:t) : unit =
    Backtrack_stack.push_level self.undo

  (* drop pending tasks, then undo the last [n] levels *)
  let pop_levels (self:t) n : unit =
    Vec.iter (N.set_field field_is_pending false) self.pending;
    Vec.clear self.pending;
    Vec.clear self.combine;
    Log.debugf 15
      (fun k->k "(@[cc.pop-levels %d@ :n-lvls %d@])" n (Backtrack_stack.n_levels self.undo));
    Backtrack_stack.pop_levels self.undo n ~f:(fun f -> f());
    ()

  (* assert that this boolean literal holds.
     if a lit is [= a b], merge [a] and [b];
     otherwise merge its term with [true] or [false] according to its sign *)
  let assert_lit cc lit : unit =
    let t = A.Lit.term lit in
    Log.debugf 5 (fun k->k "(@[cc.assert_lit@ %a@])" A.Lit.pp lit);
    let sign = A.Lit.sign lit in
    begin match T.cc_view t with
      | Eq (a,b) when sign ->
        let a = add_term cc a in
        let b = add_term cc b in
        (* merge [a] and [b] *)
        push_combine cc a b (Expl.mk_lit lit)
      | _ ->
        (* equate t and true/false *)
        let rhs = if sign then true_ cc else false_ cc in
        let n = add_term cc t in
        (* TODO: ensure that this is O(1).
           basically, just have [n] point to true/false and thus acquire
           the corresponding value, so its superterms (like [ite]) can evaluate
           properly *)
        push_combine cc n rhs (Expl.mk_lit lit)
    end

  let[@inline] assert_lits cc lits : unit =
    Sequence.iter (assert_lit cc) lits

  (* queue the merge of [t1] and [t2], justified by the literals [e] *)
  let assert_eq cc t1 t2 (e:lit list) : unit =
    let expl = Expl.mk_lits e in
    let n1 = add_term cc t1 in
    let n2 = add_term cc t2 in
    push_combine cc n1 n2 expl

  (* generative tag used to annotate classes that can't be merged *)
  let distinct_tag_ = ref 0

  (* queue the constraint that the terms of [l] (at least two) are pairwise
     distinct, justified by [lit] *)
  let assert_distinct cc (l:term list) ~neq:_ (lit:lit) : unit =
    assert (match l with [] | [_] -> false | _ -> true);
    let tag = CCRef.get_then_incr distinct_tag_ in
    Log.debugf 5
      (fun k->k "(@[cc.assert_distinct@ (@[%a@])@ :tag %d@])" (Util.pp_list T.pp) l tag);
    let l = List.map (add_term cc) l in
    Vec.push cc.combine @@ CT_distinct (l, tag, Expl.mk_lit lit)

  (* create a congruence closure; the nodes for [true] and [false] are
     internalized immediately *)
  let create ?on_merge ?(size=`Big) (tst:term_state) : t =
    let size = match size with `Small -> 128 | `Big -> 2048 in
    let rec cc = {
      tst;
      tbl = T_tbl.create size;
      signatures_tbl = Sig_tbl.create size;
      on_merge;
      pending=Vec.create();
      combine=Vec.create();
      ps_lits=[];
      undo=Backtrack_stack.create();
      ps_queue=Vec.create();
      true_;
      false_;
    } and true_ = lazy (
        add_term cc (T.bool tst true)
      ) and false_ = lazy (
        add_term cc (T.bool tst false)
      )
    in
    ignore (Lazy.force true_ : node);
    ignore (Lazy.force false_ : node);
    cc

  (* representative of the (already internalized) term [t];
     raises [Not_found] if [t] was never added *)
  let[@inline] find_t cc t : repr =
    let n = T_tbl.find cc.tbl t in
    find_ n

  let[@inline] check cc acts : unit =
    Log.debug 5 "(cc.check)";
    update_tasks cc acts

  (* model: map each uninterpreted equiv class to some ID *)
  let mk_model (cc:t) (m:A.Model.t) : A.Model.t =
    let module Model = A.Model in
    let module Value = A.Value in
    Log.debugf 15 (fun k->k "(@[cc.mk-model@ %a@])" pp_full cc);
    let t_tbl = N_tbl.create 32 in
    (* populate [repr -> value] table *)
    T_tbl.values cc.tbl
      (fun r ->
         if N.is_root r then (
           (* find a value in the class, if any *)
           let v =
             N.iter_class r
             |> Sequence.find_map (fun n -> Model.eval m n.n_term)
           in
           let v = match v with
             | Some v -> v
             | None ->
               if same_class r (true_ cc) then Value.true_
               else if same_class r (false_ cc) then Value.false_
               else Value.fresh r.n_term
           in
           N_tbl.add t_tbl r v
         ));
    (* now map every term to its representative's value *)
    let pairs =
      T_tbl.values cc.tbl
      |> Sequence.map
        (fun n ->
           let r = find_ n in
           let v =
             try N_tbl.find t_tbl r
             with Not_found ->
               Error.errorf "didn't allocate a value for repr %a" N.pp r
           in
           n.n_term, v)
    in
    let m = Sequence.fold (fun m (t,v) -> Model.add t v m) m pairs in
    Log.debugf 5 (fun k->k "(@[cc.model@ %a@])" Model.pp m);
    m
end