feat(cc): split sub-library sidekick.cc, make it fully functorized

This commit is contained in:
Simon Cruanes 2019-02-09 20:24:29 -06:00
parent de1653bdcc
commit a463dbb4b5
39 changed files with 1558 additions and 1237 deletions

112
src/cc/CC_types.ml Normal file
View file

@ -0,0 +1,112 @@
(** {1 Types used by the congruence closure} *)
type ('f, 't, 'ts) view =
| Bool of bool
| App_fun of 'f * 'ts
| App_ho of 't * 'ts
| If of 't * 't * 't
| Eq of 't * 't
| Opaque of 't (* do not enter *)
let[@inline] map_view ~f_f ~f_t ~f_ts (v:_ view) : _ view =
match v with
| Bool b -> Bool b
| App_fun (f, args) -> App_fun (f_f f, f_ts args)
| App_ho (f, args) -> App_ho (f_t f, f_ts args)
| If (a,b,c) -> If (f_t a, f_t b, f_t c)
| Eq (a,b) -> Eq (f_t a, f_t b)
| Opaque t -> Opaque (f_t t)
let iter_view ~f_f ~f_t ~f_ts (v:_ view) : unit =
match v with
| Bool _ -> ()
| App_fun (f, args) -> f_f f; f_ts args
| App_ho (f, args) -> f_t f; f_ts args
| If (a,b,c) -> f_t a; f_t b; f_t c;
| Eq (a,b) -> f_t a; f_t b
| Opaque t -> f_t t
module type TERM = sig
module Fun : sig
type t
val equal : t -> t -> bool
val hash : t -> int
val pp : t Fmt.printer
end
module Term : sig
type t
val equal : t -> t -> bool
val hash : t -> int
val pp : t Fmt.printer
type state
val bool : state -> bool -> t
(** View the term through the lens of the congruence closure *)
val cc_view : t -> (Fun.t, t, t Sequence.t) view
end
end
module type TERM_LIT = sig
include TERM
module Lit : sig
type t
val neg : t -> t
val equal : t -> t -> bool
val hash : t -> int
val pp : t Fmt.printer
val sign : t -> bool
val term : t -> Term.t
end
end
module type FULL = sig
include TERM_LIT
module Proof : sig
type t
val pp : t Fmt.printer
val default : t
(* TODO: to give more details
val cc_lemma : unit -> t
*)
end
module Ty : sig
type t
val equal : t -> t -> bool
val hash : t -> int
val pp : t Fmt.printer
end
module Value : sig
type t
val pp : t Fmt.printer
val fresh : Term.t -> t
val true_ : t
val false_ : t
end
module Model : sig
type t
val pp : t Fmt.printer
val eval : t -> Term.t -> Value.t option
(** Evaluate the term in the current model *)
val add : Term.t -> Value.t -> t -> t
end
end
(* TODO: micro theory *)

View file

@ -0,0 +1,939 @@
open CC_types
module type ARG = Congruence_closure_intf.ARG
module type S = Congruence_closure_intf.S
module Bits = CCBitField.Make()
let field_is_pending = Bits.mk_field()
(** true iff the node is in the [cc.pending] queue *)
let () = Bits.freeze()
type payload = Congruence_closure_intf.payload = ..
module Make(A: ARG) = struct
type term = A.Term.t
type term_state = A.Term.state
type lit = A.Lit.t
type fun_ = A.Fun.t
type proof = A.Proof.t
type value = A.Value.t
type model = A.Model.t
(** Actions available to the theory *)
type sat_actions = (Msat.void, lit, Msat.void, proof) Msat.acts
module T = A.Term
module Fun = A.Fun
(** A node of the congruence closure.
An equivalence class is represented by its "root" element,
the representative. *)
type node = {
n_term: term;
mutable n_sig0: signature option; (* initial signature *)
mutable n_bits: Bits.t; (* bitfield for various properties *)
mutable n_parents: node Bag.t; (* parent terms of this node *)
mutable n_root: node; (* representative of congruence class (itself if a representative) *)
mutable n_next: node; (* pointer to next element of congruence class *)
mutable n_size: int; (* size of the class *)
mutable n_as_lit: lit option;
mutable n_expl: explanation_forest_link; (* the rooted forest for explanations *)
mutable n_payload: payload list; (* list of theory payloads *)
(* TODO: make a micro theory and move this inside *)
mutable n_tags: (node * explanation) Util.Int_map.t;
(* "distinct" tags (i.e. set of `(distinct t1…tn)` terms this belongs to *)
}
and signature = (fun_, node, node list) view
and explanation_forest_link =
| FL_none
| FL_some of {
next: node;
expl: explanation;
}
(* atomic explanation in the congruence closure *)
and explanation =
| E_reduction (* by pure reduction, tautologically equal *)
| E_merges of (node * node) list (* caused by these merges *)
| E_lit of lit (* because of this literal *)
| E_lits of lit list (* because of this (true) conjunction *)
(* TODO: congruence case (cheaper than "merges") *)
type repr = node
type conflict = lit list
module N = struct
type t = node
let[@inline] equal (n1:t) n2 = T.equal n1.n_term n2.n_term
let[@inline] hash n = T.hash n.n_term
let[@inline] term n = n.n_term
let[@inline] payload n = n.n_payload
let[@inline] pp out n = T.pp out n.n_term
let[@inline] as_lit n = n.n_as_lit
let make (t:term) : t =
let rec n = {
n_term=t;
n_sig0= None;
n_bits=Bits.empty;
n_parents=Bag.empty;
n_as_lit=None; (* TODO: provide a method to do it *)
n_root=n;
n_expl=FL_none;
n_payload=[];
n_next=n;
n_size=1;
n_tags=Util.Int_map.empty;
} in
n
type nonrec payload = payload = ..
let set_payload ?(can_erase=fun _->false) n e =
let rec aux = function
| [] -> [e]
| e' :: tail when can_erase e' -> e :: tail
| e' :: tail -> e' :: aux tail
in
n.n_payload <- aux n.n_payload
let payload_find ~f:p n =
let[@unroll 2] rec aux = function
| [] -> None
| e1 :: tail ->
match p e1 with
| Some _ as res -> res
| None -> aux tail
in
aux n.n_payload
let payload_pred ~f:p n =
begin match n.n_payload with
| [] -> false
| e :: _ when p e -> true
| _ :: e :: _ when p e -> true
| _ :: _ :: e :: _ when p e -> true
| l -> List.exists p l
end
let[@inline] get_field f t = Bits.get f t.n_bits
let[@inline] set_field f b t = t.n_bits <- Bits.set f b t.n_bits
end
module N_tbl = CCHashtbl.Make(N)
module Expl = struct
type t = explanation
let equal (a:t) b =
match a, b with
| E_merges l1, E_merges l2 ->
CCList.equal (CCPair.equal N.equal N.equal) l1 l2
| E_reduction, E_reduction -> true
| E_lit l1, E_lit l2 -> A.Lit.equal l1 l2
| E_lits l1, E_lits l2 -> CCList.equal A.Lit.equal l1 l2
| E_merges _, _ | E_lit _, _ | E_lits _, _ | E_reduction, _
-> false
let hash (a:t) : int =
let module H = CCHash in
match a with
| E_lit lit -> H.combine2 10 (A.Lit.hash lit)
| E_lits l ->
H.combine2 20 (H.list A.Lit.hash l)
| E_merges l ->
H.combine2 30 (H.list (H.pair N.hash N.hash) l)
| E_reduction -> H.int 40
let pp out (e:explanation) = match e with
| E_reduction -> Fmt.string out "reduction"
| E_lit lit -> A.Lit.pp out lit
| E_lits l -> CCFormat.Dump.list A.Lit.pp out l
| E_merges l ->
Format.fprintf out "(@[<hv1>merges@ %a@])"
Fmt.(seq ~sep:(return "@ ") @@ within "[" "]" @@ hvbox @@
pair ~sep:(return " ~@ ") N.pp N.pp)
(Sequence.of_list l)
let[@inline] mk_merges l : t = E_merges l
let[@inline] mk_lit l : t = E_lit l
let[@inline] mk_lits = function [x] -> mk_lit x | l -> E_lits l
let mk_reduction : t = E_reduction
end
(** A signature is a shallow term shape where immediate subterms
are representative *)
module Signature = struct
type t = signature
let equal (s1:t) s2 : bool =
match s1, s2 with
| Bool b1, Bool b2 -> b1=b2
| App_fun (f1,[]), App_fun (f2,[]) -> Fun.equal f1 f2
| App_fun (f1,l1), App_fun (f2,l2) ->
Fun.equal f1 f2 && CCList.equal N.equal l1 l2
| App_ho (f1,l1), App_ho (f2,l2) ->
N.equal f1 f2 && CCList.equal N.equal l1 l2
| If (a1,b1,c1), If (a2,b2,c2) ->
N.equal a1 a2 && N.equal b1 b2 && N.equal c1 c2
| Eq (a1,b1), Eq (a2,b2) ->
N.equal a1 a2 && N.equal b1 b2
| Opaque u1, Opaque u2 -> N.equal u1 u2
| Bool _, _ | App_fun _, _ | App_ho _, _ | If _, _
| Eq _, _ | Opaque _, _
-> false
let hash (s:t) : int =
let module H = CCHash in
match s with
| Bool b -> H.combine2 10 (H.bool b)
| App_fun (f, l) -> H.combine3 20 (Fun.hash f) (H.list N.hash l)
| App_ho (f, l) -> H.combine3 30 (N.hash f) (H.list N.hash l)
| Eq (a,b) -> H.combine3 40 (N.hash a) (N.hash b)
| Opaque u -> H.combine2 50 (N.hash u)
| If (a,b,c) -> H.combine4 60 (N.hash a)(N.hash b)(N.hash c)
let pp out = function
| Bool b -> Fmt.bool out b
| App_fun (f, []) -> Fun.pp out f
| App_fun (f, l) -> Fmt.fprintf out "(@[%a@ %a@])" Fun.pp f (Util.pp_list N.pp) l
| App_ho (f, []) -> N.pp out f
| App_ho (f, l) -> Fmt.fprintf out "(@[%a@ %a@])" N.pp f (Util.pp_list N.pp) l
| Opaque t -> N.pp out t
| Eq (a,b) -> Fmt.fprintf out "(@[=@ %a@ %a@])" N.pp a N.pp b
| If (a,b,c) -> Fmt.fprintf out "(@[ite@ %a@ %a@ %a@])" N.pp a N.pp b N.pp c
end
module Sig_tbl = CCHashtbl.Make(Signature)
module T_tbl = CCHashtbl.Make(T)
type combine_task =
| CT_merge of node * node * explanation
| CT_distinct of node list * int * explanation
type t = {
tst: term_state;
tbl: node T_tbl.t;
(* internalization [term -> node] *)
signatures_tbl : node Sig_tbl.t;
(* map a signature to the corresponding node in some equivalence class.
A signature is a [term_cell] in which every immediate subterm
that participates in the congruence/evaluation relation
is normalized (i.e. is its own representative).
The critical property is that all members of an equivalence class
that have the same "shape" (including head symbol)
have the same signature *)
pending: node Vec.t;
combine: combine_task Vec.t;
undo: (unit -> unit) Backtrack_stack.t;
on_merge: (repr -> repr -> explanation -> unit) option;
mutable ps_lits: lit list; (* TODO: thread it around instead? *)
(* proof state *)
ps_queue: (node*node) Vec.t;
(* pairs to explain *)
true_ : node lazy_t;
false_ : node lazy_t;
}
(* TODO: an additional union-find to keep track, for each term,
of the terms they are known to be equal to, according
to the current explanation. That allows not to prove some equality
several times.
See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *)
let[@inline] is_root_ (n:node) : bool = n.n_root == n
let[@inline] size_ (r:repr) = r.n_size
let[@inline] true_ cc = Lazy.force cc.true_
let[@inline] false_ cc = Lazy.force cc.false_
let[@inline] on_backtrack cc f : unit =
Backtrack_stack.push_if_nonzero_level cc.undo f
(* check if [t] is in the congruence closure.
Invariant: [in_cc t do_cc t => forall u subterm t, in_cc u] *)
let[@inline] mem (cc:t) (t:term): bool = T_tbl.mem cc.tbl t
(* find representative, recursively *)
let[@unroll 1] rec find_rec (n:node) : repr =
if n==n.n_root then (
n
) else (
(* TODO: path compression, assuming backtracking restores equiv classes
properly *)
let root = find_rec n.n_root in
root
)
(* traverse the equivalence class of [n] *)
let iter_class_ (n:node) : node Sequence.t =
fun yield ->
let rec aux u =
yield u;
if u.n_next != n then aux u.n_next
in
aux n
(* non-recursive, inlinable function for [find] *)
let[@inline] find_ (n:node) : repr =
if n == n.n_root then n else find_rec n.n_root
let[@inline] same_class (n1:node)(n2:node): bool =
N.equal (find_ n1) (find_ n2)
let[@inline] find _ n = find_ n
(* print full state *)
let pp_full out (cc:t) : unit =
let pp_next out n =
Fmt.fprintf out "@ :next %a" N.pp n.n_next in
let pp_root out n =
if is_root_ n then Fmt.string out " :is-root" else Fmt.fprintf out "@ :root %a" N.pp n.n_root in
let pp_expl out n = match n.n_expl with
| FL_none -> ()
| FL_some e ->
Fmt.fprintf out " (@[:forest %a :expl %a@])" N.pp e.next Expl.pp e.expl
in
let pp_n out n =
Fmt.fprintf out "(@[%a%a%a%a@])" T.pp n.n_term pp_root n pp_next n pp_expl n
and pp_sig_e out (s,n) =
Fmt.fprintf out "(@[<1>%a@ ~~> %a%a@])" Signature.pp s N.pp n pp_root n
in
Fmt.fprintf out
"(@[@{<yellow>cc.state@}@ (@[<hv>:nodes@ %a@])@ (@[<hv>:sig-tbl@ %a@])@])"
(Util.pp_seq ~sep:" " pp_n) (T_tbl.values cc.tbl)
(Util.pp_seq ~sep:" " pp_sig_e) (Sig_tbl.to_seq cc.signatures_tbl)
(* compute up-to-date signature *)
let update_sig (s:signature) : Signature.t =
CC_types.map_view s
~f_f:(fun x->x)
~f_t:find_
~f_ts:(List.map find_)
(* find whether the given (parent) term corresponds to some signature
in [signatures_] *)
let[@inline] find_signature cc (s:signature) : repr option =
Sig_tbl.get cc.signatures_tbl s
let add_signature cc (s:signature) (n:node) : unit =
(* add, but only if not present already *)
match Sig_tbl.find cc.signatures_tbl s with
| exception Not_found ->
Log.debugf 15
(fun k->k "(@[cc.add-sig@ %a@ ~~> %a@])" Signature.pp s N.pp n);
on_backtrack cc (fun () -> Sig_tbl.remove cc.signatures_tbl s);
Sig_tbl.add cc.signatures_tbl s n;
| r' ->
assert (same_class n r');
()
let push_pending cc t : unit =
if not @@ N.get_field field_is_pending t then (
Log.debugf 5 (fun k->k "(@[<hv1>cc.push_pending@ %a@])" N.pp t);
N.set_field field_is_pending true t;
Vec.push cc.pending t
)
let push_combine cc t u e : unit =
Log.debugf 5
(fun k->k "(@[<hv1>cc.push_combine@ %a ~@ %a@ :expl %a@])"
N.pp t N.pp u Expl.pp e);
Vec.push cc.combine @@ CT_merge (t,u,e)
(* re-root the explanation tree of the equivalence class of [n]
so that it points to [n].
postcondition: [n.n_expl = None] *)
let rec reroot_expl (cc:t) (n:node): unit =
let old_expl = n.n_expl in
begin match old_expl with
| FL_none -> () (* already root *)
| FL_some {next=u; expl=e_n_u} ->
reroot_expl cc u;
u.n_expl <- FL_some {next=n; expl=e_n_u};
n.n_expl <- FL_none;
end
let raise_conflict (cc:t) (acts:sat_actions) (e:conflict): _ =
(* clear tasks queue *)
Vec.iter (N.set_field field_is_pending false) cc.pending;
Vec.clear cc.pending;
Vec.clear cc.combine;
let c = List.rev_map A.Lit.neg e in
acts.Msat.acts_raise_conflict c A.Proof.default
let[@inline] all_classes cc : repr Sequence.t =
T_tbl.values cc.tbl
|> Sequence.filter is_root_
(* TODO: use markers and lockstep iteration instead *)
(* distance from [t] to its root in the proof forest *)
let[@inline][@unroll 2] rec distance_to_root (n:node): int = match n.n_expl with
| FL_none -> 0
| FL_some {next=t'; _} -> 1 + distance_to_root t'
(* TODO: bool flag on nodes + stepwise progress + cleanup *)
(* find the closest common ancestor of [a] and [b] in the proof forest *)
let find_common_ancestor (a:node) (b:node) : node =
let d_a = distance_to_root a in
let d_b = distance_to_root b in
(* drop [n] nodes in the path from [t] to its root *)
let rec drop_ n t =
if n=0 then t
else match t.n_expl with
| FL_none -> assert false
| FL_some {next=t'; _} -> drop_ (n-1) t'
in
(* reduce to the problem where [a] and [b] have the same distance to root *)
let a, b =
if d_a > d_b then drop_ (d_a-d_b) a, b
else if d_a < d_b then a, drop_ (d_b-d_a) b
else a, b
in
(* traverse stepwise until a==b *)
let rec aux_same_dist a b =
if a==b then a
else match a.n_expl, b.n_expl with
| FL_none, _ | _, FL_none -> assert false
| FL_some {next=a'; _}, FL_some {next=b'; _} -> aux_same_dist a' b'
in
aux_same_dist a b
let[@inline] ps_add_obligation (cc:t) a b = Vec.push cc.ps_queue (a,b)
let[@inline] ps_add_lit ps l = ps.ps_lits <- l :: ps.ps_lits
(* TODO: remove *)
let ps_clear (cc:t) =
cc.ps_lits <- [];
Vec.clear cc.ps_queue;
()
let decompose_explain cc (e:explanation): unit =
Log.debugf 5 (fun k->k "(@[cc.decompose_expl@ %a@])" Expl.pp e);
begin match e with
| E_reduction -> ()
| E_lit lit -> ps_add_lit cc lit
| E_lits l -> List.iter (ps_add_lit cc) l
| E_merges l ->
(* need to explain each merge in [l] *)
List.iter (fun (t,u) -> ps_add_obligation cc t u) l
end
(* explain why [a = parent_a], where [a -> ... -> parent_a] in the
proof forest *)
let rec explain_along_path ps (a:node) (parent_a:node) : unit =
if a!=parent_a then (
match a.n_expl with
| FL_none -> assert false
| FL_some {next=next_a; expl=e_a_b} ->
decompose_explain ps e_a_b;
(* now prove [next_a = parent_a] *)
explain_along_path ps next_a parent_a
)
(* find explanation *)
let explain_loop (cc : t) : lit list =
while not (Vec.is_empty cc.ps_queue) do
let a, b = Vec.pop cc.ps_queue in
Log.debugf 5
(fun k->k "(@[cc.explain_loop.at@ %a@ =?= %a@])" N.pp a N.pp b);
assert (N.equal (find_ a) (find_ b));
let c = find_common_ancestor a b in
explain_along_path cc a c;
explain_along_path cc b c;
done;
cc.ps_lits
(* TODO: do not use ps_lits anymore *)
let explain_eq_n ?(init=[]) cc (n1:node) (n2:node) : lit list =
ps_clear cc;
cc.ps_lits <- init;
ps_add_obligation cc n1 n2;
explain_loop cc
let explain_unfold ?(init=[]) cc (e:explanation) : lit list =
ps_clear cc;
cc.ps_lits <- init;
decompose_explain cc e;
explain_loop cc
(* add [tag] to [n], indicating that [n] is distinct from all the other
nodes tagged with [tag]
precond: [n] is a representative *)
let add_tag_n cc (n:node) (tag:int) (expl:explanation) : unit =
assert (is_root_ n);
if not (Util.Int_map.mem tag n.n_tags) then (
on_backtrack cc
(fun () -> n.n_tags <- Util.Int_map.remove tag n.n_tags);
n.n_tags <- Util.Int_map.add tag (n,expl) n.n_tags;
)
(* add a term *)
let [@inline] rec add_term_rec_ cc t : node =
try T_tbl.find cc.tbl t
with Not_found -> add_new_term_ cc t
(* add [t] to [cc] when not present already *)
and add_new_term_ cc (t:term) : node =
assert (not @@ mem cc t);
Log.debugf 15 (fun k->k "(@[cc.add-term@ %a@])" T.pp t);
let n = N.make t in
(* register sub-terms, add [t] to their parent list, and return the
corresponding initial signature *)
let sig0 = compute_sig0 cc n in
n.n_sig0 <- sig0;
(* remove term when we backtrack *)
on_backtrack cc
(fun () ->
Log.debugf 15 (fun k->k "(@[cc.remove-term@ %a@])" T.pp t);
T_tbl.remove cc.tbl t);
(* add term to the table *)
T_tbl.add cc.tbl t n;
if CCOpt.is_some sig0 then (
(* [n] might be merged with other equiv classes *)
push_pending cc n;
);
n
(* compute the initial signature of the given node *)
and compute_sig0 (self:t) (n:node) : Signature.t option =
(* add sub-term to [cc], and register [n] to its parents *)
let deref_sub (u:term) : node =
let sub = add_term_rec_ self u in
(* add [n] to [sub.root]'s parent list *)
begin
let sub = find_ sub in
let old_parents = sub.n_parents in
on_backtrack self (fun () -> sub.n_parents <- old_parents);
sub.n_parents <- Bag.cons n sub.n_parents;
end;
sub
in
let[@inline] return x = Some x in
match T.cc_view n.n_term with
| Bool _ | Opaque _ -> None
| Eq (a,b) ->
let a = deref_sub a in
let b = deref_sub b in
return @@ Eq (a,b)
| App_fun (f, args) ->
let args = args |> Sequence.map deref_sub |> Sequence.to_list in
if args<>[] then (
return @@ App_fun (f, args)
) else None
| App_ho (f, args) ->
let args = args |> Sequence.map deref_sub |> Sequence.to_list in
return @@ App_ho (deref_sub f, args)
| If (a,b,c) ->
return @@ If (deref_sub a, deref_sub b, deref_sub c)
let[@inline] add_term cc t : node = add_term_rec_ cc t
let[@inline] add_term' cc t : unit = ignore (add_term_rec_ cc t : node)
let set_as_lit cc (n:node) (lit:lit) : unit =
match n.n_as_lit with
| Some _ -> ()
| None ->
Log.debugf 15 (fun k->k "(@[cc.set-as-lit@ %a@ %a@])" N.pp n A.Lit.pp lit);
on_backtrack cc (fun () -> n.n_as_lit <- None);
n.n_as_lit <- Some lit
(* Checks if [ra] and [~into] have compatible normal forms and can
be merged w.r.t. the theories.
Side effect: also pushes sub-tasks *)
let notify_merge cc (ra:repr) ~into:(rb:repr) (e:explanation): unit =
assert (is_root_ rb);
match cc.on_merge with
| Some f -> f ra rb e
| None -> ()
let[@inline] n_is_bool (self:t) n : bool =
N.equal n (true_ self) || N.equal n (false_ self)
(* main CC algo: add terms from [pending] to the signature table,
check for collisions *)
let rec update_tasks (cc:t) (acts:sat_actions) : unit =
while not (Vec.is_empty cc.pending && Vec.is_empty cc.combine) do
while not @@ Vec.is_empty cc.pending do
task_pending_ cc (Vec.pop cc.pending);
done;
while not @@ Vec.is_empty cc.combine do
task_combine_ cc acts (Vec.pop cc.combine);
done;
done
and task_pending_ cc (n:node) : unit =
N.set_field field_is_pending false n;
(* check if some parent collided *)
begin match n.n_sig0 with
| None -> () (* no-op *)
| Some (Eq (a,b)) ->
(* if [a=b] is now true, merge [(a=b)] and [true] *)
if same_class a b then (
let expl = Expl.mk_merges [(a,b)] in
push_combine cc n (true_ cc) expl
)
| Some s0 ->
(* update the signature by using [find] on each sub-node *)
let s = update_sig s0 in
match find_signature cc s with
| None ->
(* add to the signature table [sig(n) --> n] *)
add_signature cc s n
| Some u when n == u -> ()
| Some u ->
(* [t1] and [t2] must be applications of the same symbol to
arguments that are pairwise equal *)
assert (n != u);
let expl =
match n.n_sig0, u.n_sig0 with
| Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) ->
assert (Fun.equal f1 f2);
assert (List.length a1 = List.length a2);
(* TODO: just use "congruence" as explanation *)
Expl.mk_merges @@ List.combine a1 a2
| Some (App_ho (f1, a1)), Some (App_ho (f2, a2)) ->
assert (List.length a1 = List.length a2);
(* TODO: just use "congruence" as explanation *)
Expl.mk_merges @@ (f1,f2)::List.combine a1 a2
| Some (If (a1,b1,c1)), Some (If (a2,b2,c2)) ->
Expl.mk_merges @@ [a1,a2; b1,b2; c1,c2]
| _
-> assert false
in
push_combine cc n u expl
(* FIXME: when to actually evaluate?
eval_pending cc;
*)
end
and[@inline] task_combine_ cc acts = function
| CT_merge (a,b,e_ab) -> task_merge_ cc acts a b e_ab
| CT_distinct (l,tag,e) -> task_distinct_ cc acts l tag e
(* main CC algo: merge equivalence classes in [st.combine].
@raise Exn_unsat if merge fails *)
and task_merge_ cc acts a b e_ab : unit =
let ra = find_ a in
let rb = find_ b in
if not @@ N.equal ra rb then (
assert (is_root_ ra);
assert (is_root_ rb);
(* check we're not merging [true] and [false] *)
if (N.equal ra (true_ cc) && N.equal rb (false_ cc)) ||
(N.equal rb (true_ cc) && N.equal ra (false_ cc)) then (
Log.debugf 5
(fun k->k "(@[<hv>cc.merge.true_false_conflict@ @[:r1 %a@]@ @[:r2 %a@]@ :e_ab %a@])"
N.pp ra N.pp rb Expl.pp e_ab);
let lits = explain_unfold cc e_ab in
let lits = explain_eq_n ~init:lits cc a ra in
let lits = explain_eq_n ~init:lits cc b rb in
raise_conflict cc acts lits
);
(* We will merge [r_from] into [r_into].
we try to ensure that [size ra <= size rb] in general, but always
keep values as representative *)
let r_from, r_into =
if n_is_bool cc ra then rb, ra
else if n_is_bool cc rb then ra, rb
else if size_ ra > size_ rb then rb, ra
else ra, rb
in
(* TODO: instead call micro theories, including "distinct" *)
(* update set of tags the new node cannot be equal to *)
let new_tags =
Util.Int_map.union
(fun _i (n1,e1) (n2,e2) ->
(* both maps contain same tag [_i]. conflict clause:
[e1 & e2 & e_ab] impossible *)
Log.debugf 5
(fun k->k "(@[<hv>cc.merge.distinct_conflict@ :tag %d@ \
@[:r1 %a@ :e1 %a@]@ @[:r2 %a@ :e2 %a@]@ :e_ab %a@])"
_i N.pp n1 Expl.pp e1
N.pp n2 Expl.pp e2 Expl.pp e_ab);
let lits = explain_unfold cc e1 in
let lits = explain_unfold ~init:lits cc e2 in
let lits = explain_unfold ~init:lits cc e_ab in
let lits = explain_eq_n ~init:lits cc a n1 in
let lits = explain_eq_n ~init:lits cc b n2 in
raise_conflict cc acts lits)
ra.n_tags rb.n_tags
in
(* when merging terms with [true] or [false], possibly propagate them to SAT *)
let merge_bool r1 t1 r2 t2 =
if N.equal r1 (true_ cc) then (
propagate_bools cc acts r2 t2 r1 t1 e_ab true
) else if N.equal r1 (false_ cc) then (
propagate_bools cc acts r2 t2 r1 t1 e_ab false
)
in
merge_bool ra a rb b;
merge_bool rb b ra a;
(* perform [union r_from r_into] *)
Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into);
(* TODO: only iterate on parents of [rb] *)
(* TODO: [ra.parents <- ra.parent ++ rb.parents] *)
begin
(* for each node in [r_from]'s class:
- make it point to [r_into]
- push it into [st.pending] *)
iter_class_ r_from
(fun u ->
assert (u.n_root == r_from);
on_backtrack cc (fun () -> u.n_root <- r_from);
u.n_root <- r_into;
Bag.to_seq u.n_parents
(fun parent -> push_pending cc parent));
(* now merge the classes *)
let r_into_old_tags = r_into.n_tags in
let r_into_old_next = r_into.n_next in
let r_from_old_next = r_from.n_next in
on_backtrack cc
(fun () ->
Log.debugf 15
(fun k->k "(@[cc.undo_merge@ :from %a :into %a@])"
N.pp r_from N.pp r_into);
r_into.n_next <- r_into_old_next;
r_from.n_next <- r_from_old_next;
r_into.n_tags <- r_into_old_tags);
r_into.n_tags <- new_tags;
(* swap [into.next] and [from.next], merging the classes *)
r_into.n_next <- r_from_old_next;
r_from.n_next <- r_into_old_next;
end;
(* update explanations (a -> b), arbitrarily.
Note that here we merge the classes by adding a bridge between [a]
and [b], not their roots. *)
begin
reroot_expl cc a;
assert (a.n_expl = FL_none);
(* on backtracking, link may be inverted, but we delete the one
that bridges between [a] and [b] *)
on_backtrack cc
(fun () -> match a.n_expl, b.n_expl with
| FL_some e, _ when N.equal e.next b -> a.n_expl <- FL_none
| _, FL_some e when N.equal e.next a -> b.n_expl <- FL_none
| _ -> assert false);
a.n_expl <- FL_some {next=b; expl=e_ab};
end;
(* notify listeners of the merge *)
notify_merge cc r_from ~into:r_into e_ab;
)
and task_distinct_ cc acts (l:node list) tag expl : unit =
let l = List.map (fun n -> n, find_ n) l in
let coll =
Sequence.diagonal_l l
|> Sequence.find_pred (fun ((_,r1),(_,r2)) -> N.equal r1 r2)
in
begin match coll with
| Some ((n1,_r1),(n2,_r2)) ->
(* two classes are already equal *)
Log.debugf 5
(fun k->k "(@[cc.distinct.conflict@ %a = %a@ :expl %a@])" N.pp n1 N.pp
n2 Expl.pp expl);
let lits = explain_unfold cc expl in
raise_conflict cc acts lits
| None ->
(* put a tag on all equivalence classes, that will make their merge fail *)
List.iter (fun (_,n) -> add_tag_n cc n tag expl) l
end
(* we are merging [r1] with [r2==Bool(sign)], so propagate each term [u1]
in the equiv class of [r1] that is a known literal back to the SAT solver
and which is not the one initially merged.
We can explain the propagation with [u1 = t1 =e= t2 = r2==bool] *)
and propagate_bools cc acts r1 t1 r2 t2 (e_12:explanation) sign : unit =
(* explanation for [t1 =e= t2 = r2] *)
let half_expl = lazy (
let expl = explain_unfold cc e_12 in
explain_eq_n ~init:expl cc r2 t2
) in
iter_class_ r1
(fun u1 ->
(* propagate if:
- [u1] is a proper literal
- [t2 != r2], because that can only happen
after an explicit merge (no way to obtain that by propagation)
*)
match N.as_lit u1 with
| Some lit when not (N.equal r2 t2) ->
let lit = if sign then lit else A.Lit.neg lit in (* apply sign *)
Log.debugf 5 (fun k->k "(@[cc.bool_propagate@ %a@])" A.Lit.pp lit);
(* complete explanation with the [u1=t1] chunk *)
let expl = explain_eq_n ~init:(Lazy.force half_expl) cc u1 t1 in
let reason = Msat.Consequence (expl, A.Proof.default) in
acts.Msat.acts_propagate lit reason
| _ -> ())
let check_invariants_ (cc:t) =
Log.debug 5 "(cc.check-invariants)";
Log.debugf 15 (fun k-> k "%a" pp_full cc);
assert (T.equal (T.bool cc.tst true) (true_ cc).n_term);
assert (T.equal (T.bool cc.tst false) (false_ cc).n_term);
assert (not @@ same_class (true_ cc) (false_ cc));
assert (Vec.is_empty cc.combine);
assert (Vec.is_empty cc.pending);
(* check that subterms are internalized *)
T_tbl.iter
(fun t n ->
assert (T.equal t n.n_term);
assert (not @@ N.get_field field_is_pending n);
assert (N.equal n.n_root n.n_next.n_root);
(* check proper signature.
note that some signatures in the sig table can be obsolete (they
were not removed) but there must be a valid, up-to-date signature for
each term *)
begin match CCOpt.map update_sig n.n_sig0 with
| None -> ()
| Some s ->
Log.debugf 15 (fun k->k "(@[cc.check-sig@ %a@ :sig %a@])" T.pp t Signature.pp s);
(* add, but only if not present already *)
begin match Sig_tbl.find cc.signatures_tbl s with
| exception Not_found -> assert false
| repr_s -> assert (same_class n repr_s)
end
end;
)
cc.tbl;
()
let[@inline] check_invariants (cc:t) : unit =
if Util._CHECK_INVARIANTS then check_invariants_ cc
let add_seq cc seq =
seq (fun t -> ignore @@ add_term_rec_ cc t);
()
let[@inline] push_level (self:t) : unit =
Backtrack_stack.push_level self.undo
let pop_levels (self:t) n : unit =
Vec.iter (N.set_field field_is_pending false) self.pending;
Vec.clear self.pending;
Vec.clear self.combine;
Log.debugf 15
(fun k->k "(@[cc.pop-levels %d@ :n-lvls %d@])" n (Backtrack_stack.n_levels self.undo));
Backtrack_stack.pop_levels self.undo n ~f:(fun f -> f());
()
(* assert that this boolean literal holds.
if a lit is [= a b], merge [a] and [b];
if it's [distinct a1an], make them distinct, etc. etc. *)
let assert_lit cc lit : unit =
let t = A.Lit.term lit in
Log.debugf 5 (fun k->k "(@[cc.assert_lit@ %a@])" A.Lit.pp lit);
let sign = A.Lit.sign lit in
begin match T.cc_view t with
| Eq (a,b) when sign ->
(* merge [a] and [b] *)
let a = add_term cc a in
let b = add_term cc b in
push_combine cc a b (Expl.mk_lit lit)
| _ ->
(* equate t and true/false *)
let rhs = if sign then true_ cc else false_ cc in
let n = add_term cc t in
(* TODO: ensure that this is O(1).
basically, just have [n] point to true/false and thus acquire
the corresponding value, so its superterms (like [ite]) can evaluate
properly *)
push_combine cc n rhs (Expl.mk_lit lit)
end
let[@inline] assert_lits cc lits : unit =
Sequence.iter (assert_lit cc) lits
let assert_eq cc t1 t2 (e:lit list) : unit =
let expl = Expl.mk_lits e in
let n1 = add_term cc t1 in
let n2 = add_term cc t2 in
push_combine cc n1 n2 expl
let assert_distinct cc (l:term list) ~neq (lit:lit) : unit =
assert (match l with[] | [_] -> false | _ -> true);
assert false
(* FIXME
let tag = Term.id neq in
Log.debugf 5
(fun k->k "(@[cc.assert_distinct@ (@[%a@])@ :tag %d@])" (Util.pp_list Term.pp) l tag);
let l = List.map (add cc) l in
Vec.push cc.combine @@ CT_distinct (l, tag, Expl.lit lit)
*)
let create ?on_merge ?(size=`Big) (tst:term_state) : t =
let size = match size with `Small -> 128 | `Big -> 2048 in
let rec cc = {
tst;
tbl = T_tbl.create size;
signatures_tbl = Sig_tbl.create size;
on_merge;
pending=Vec.create();
combine=Vec.create();
ps_lits=[];
undo=Backtrack_stack.create();
ps_queue=Vec.create();
true_;
false_;
} and true_ = lazy (
add_term cc (T.bool tst true)
) and false_ = lazy (
add_term cc (T.bool tst false)
)
in
ignore (Lazy.force true_ : node);
ignore (Lazy.force false_ : node);
cc
let[@inline] find_t cc t : repr =
let n = T_tbl.find cc.tbl t in
find_ n
let[@inline] check cc acts : unit =
Log.debug 5 "(cc.check)";
update_tasks cc acts
(* model: map each uninterpreted equiv class to some ID *)
let mk_model (cc:t) (m:A.Model.t) : A.Model.t =
let module Model = A.Model in
let module Value = A.Value in
Log.debugf 15 (fun k->k "(@[cc.mk-model@ %a@])" pp_full cc);
let t_tbl = N_tbl.create 32 in
(* populate [repr -> value] table *)
T_tbl.values cc.tbl
(fun r ->
if is_root_ r then (
(* find a value in the class, if any *)
let v =
iter_class_ r
|> Sequence.find_map (fun n -> Model.eval m n.n_term)
in
let v = match v with
| Some v -> v
| None ->
if same_class r (true_ cc) then Value.true_
else if same_class r (false_ cc) then Value.false_
else Value.fresh r.n_term
in
N_tbl.add t_tbl r v
));
(* now map every term to its representative's value *)
let pairs =
T_tbl.values cc.tbl
|> Sequence.map
(fun n ->
let r = find_ n in
let v =
try N_tbl.find t_tbl r
with Not_found ->
Error.errorf "didn't allocate a value for repr %a" N.pp r
in
n.n_term, v)
in
let m = Sequence.fold (fun m (t,v) -> Model.add t v m) m pairs in
Log.debugf 5 (fun k->k "(@[cc.model@ %a@])" Model.pp m);
m
end

View file

@ -0,0 +1,14 @@
(** {2 Congruence Closure} *)
module type ARG = Congruence_closure_intf.ARG
module type S = Congruence_closure_intf.S
type payload = Congruence_closure_intf.payload = ..
module Make(A: ARG)
: S with type term = A.Term.t
and type lit = A.Lit.t
and type fun_ = A.Fun.t
and type term_state = A.Term.state
and type proof = A.Proof.t
and type model = A.Model.t

View file

@ -0,0 +1,136 @@
module type ARG = CC_types.FULL
(** Theory-extensible payloads in the equivalence classes *)
type payload = ..
module type S = sig
type term_state
type term
type fun_
type lit
type proof
type model
(** Actions available to the theory *)
type sat_actions = (Msat.void, lit, Msat.void, proof) Msat.acts
type t
(** Global state of the congruence closure *)
(** An equivalence class is a set of terms that are currently equal
in the partial model built by the solver.
The class is represented by a collection of nodes, one of which is
distinguished and is called the "representative".
All information pertaining to the whole equivalence class is stored
in this representative's node.
When two classes become equal (are "merged"), one of the two
representatives is picked as the representative of the new class.
The new class contains the union of the two old classes' nodes.
We also allow theories to store additional information in the
representative. This information can be used when two classes are
merged, to detect conflicts and solve equations à la Shostak.
*)
module N : sig
type t
val term : t -> term
val equal : t -> t -> bool
val hash : t -> int
val pp : t Fmt.printer
type nonrec payload = payload = ..
val payload_find: f:(payload -> 'a option) -> t -> 'a option
val payload_pred: f:(payload -> bool) -> t -> bool
val set_payload : ?can_erase:(payload -> bool) -> t -> payload -> unit
(** Add given payload
@param can_erase if provided, checks whether an existing value
is to be replaced instead of adding a new entry *)
end
module Expl : sig
type t
val pp : t Fmt.printer
end
type node = N.t
(** A node of the congruence closure *)
type repr = N.t
(** Node that is currently a representative *)
type explanation = Expl.t
type conflict = lit list
(* TODO micro theories as parameters *)
val create :
?on_merge:(repr -> repr -> explanation -> unit) ->
?size:[`Small | `Big] ->
term_state ->
t
(** Create a new congruence closure.
@param on_merge callback to be called on every merge
*)
val find : t -> node -> repr
(** Current representative *)
val add_term : t -> term -> node
(** Add the term to the congruence closure, if not present already.
Will be backtracked. *)
val set_as_lit : t -> N.t -> lit -> unit
(** map the given node to a literal. *)
val add_term' : t -> term -> unit
(** Same as {!add_term} but ignore the result *)
val find_t : t -> term -> repr
(** Current representative of the term.
@raise Not_found if the term is not already {!add}-ed. *)
val add_seq : t -> term Sequence.t -> unit
(** Add a sequence of terms to the congruence closure *)
val all_classes : t -> repr Sequence.t
(** All current classes *)
val assert_lit : t -> lit -> unit
(** Given a literal, assume it in the congruence closure and propagate
its consequences. Will be backtracked. *)
val assert_lits : t -> lit Sequence.t -> unit
val assert_eq : t -> term -> term -> lit list -> unit
(** merge the given terms with some explanations *)
val assert_distinct : t -> term list -> neq:term -> lit -> unit
(** [assert_distinct l ~neq:u e] asserts all elements of [l] are distinct
because [lit] is true
precond: [u = distinct l] *)
val check : t -> sat_actions -> unit
(** Perform all pending operations done via {!assert_eq}, {!assert_lit}, etc.
Will use the [sat_actions] to propagate literals, declare conflicts, etc. *)
val push_level : t -> unit
val pop_levels : t -> int -> unit
val mk_model : t -> model -> model
(** Enrich a model by mapping terms to their representative's value,
if any. Otherwise map the representative to a fresh value *)
(**/**)
val check_invariants : t -> unit
val pp_full : t Fmt.printer
(**/**)
end

View file

@ -1,23 +1,34 @@
module H = CCHash
type ('f, 't, 'ts) view = ('f, 't, 'ts) Mini_cc_intf.view =
| Bool of bool
| App of 'f * 'ts
| If of 't * 't * 't
type res = Mini_cc_intf.res =
type res =
| Sat
| Unsat
module type ARG = Mini_cc_intf.ARG
module type S = Mini_cc_intf.S
module type TERM = CC_types.TERM
module type S = sig
type term
type fun_
type term_state
type t
val create : term_state -> t
val add_lit : t -> term -> bool -> unit
val distinct : t -> term list -> unit
val check : t -> res
end
module Make(A: TERM) = struct
open CC_types
module Make(A: ARG) = struct
module Fun = A.Fun
module T = A.Term
type fun_ = A.Fun.t
type term = T.t
type term_state = A.Term.state
module T_tbl = CCHashtbl.Make(T)
@ -65,49 +76,77 @@ module Make(A: ARG) = struct
let equal (s1:t) s2 : bool =
match s1, s2 with
| Bool b1, Bool b2 -> b1=b2
| App (f1,[]), App (f2,[]) -> Fun.equal f1 f2
| App (f1,l1), App (f2,l2) ->
| App_fun (f1,[]), App_fun (f2,[]) -> Fun.equal f1 f2
| App_fun (f1,l1), App_fun (f2,l2) ->
Fun.equal f1 f2 && CCList.equal Node.equal l1 l2
| App_ho (f1,l1), App_ho (f2,l2) ->
Node.equal f1 f2 && CCList.equal Node.equal l1 l2
| If (a1,b1,c1), If (a2,b2,c2) ->
Node.equal a1 a2 && Node.equal b1 b2 && Node.equal c1 c2
| Bool _, _ | App _, _ | If _, _
| Eq (a1,b1), Eq (a2,b2) ->
Node.equal a1 a2 && Node.equal b1 b2
| Opaque u1, Opaque u2 -> Node.equal u1 u2
| Bool _, _ | App_fun _, _ | App_ho _, _ | If _, _
| Eq _, _ | Opaque _, _
-> false
let hash (s:t) : int =
let module H = CCHash in
match s with
| Bool b -> H.combine2 10 (H.bool b)
| App (f, l) -> H.combine3 20 (Fun.hash f) (H.list Node.hash l)
| If (a,b,c) -> H.combine4 30 (Node.hash a)(Node.hash b)(Node.hash c)
| App_fun (f, l) -> H.combine3 20 (Fun.hash f) (H.list Node.hash l)
| App_ho (f, l) -> H.combine3 30 (Node.hash f) (H.list Node.hash l)
| Eq (a,b) -> H.combine3 40 (Node.hash a) (Node.hash b)
| Opaque u -> H.combine2 50 (Node.hash u)
| If (a,b,c) -> H.combine4 60 (Node.hash a)(Node.hash b)(Node.hash c)
let pp out = function
| Bool b -> Fmt.bool out b
| App (f, []) -> Fun.pp out f
| App (f, l) -> Fmt.fprintf out "(@[%a@ %a@])" Fun.pp f (Util.pp_list Node.pp) l
| App_fun (f, []) -> Fun.pp out f
| App_fun (f, l) -> Fmt.fprintf out "(@[%a@ %a@])" Fun.pp f (Util.pp_list Node.pp) l
| App_ho (f, []) -> Node.pp out f
| App_ho (f, l) -> Fmt.fprintf out "(@[%a@ %a@])" Node.pp f (Util.pp_list Node.pp) l
| Opaque t -> Node.pp out t
| Eq (a,b) -> Fmt.fprintf out "(@[=@ %a@ %a@])" Node.pp a Node.pp b
| If (a,b,c) -> Fmt.fprintf out "(@[ite@ %a@ %a@ %a@])" Node.pp a Node.pp b Node.pp c
end
module Sig_tbl = CCHashtbl.Make(Signature)
type t = {
mutable ok: bool; (* unsat? *)
tbl: node T_tbl.t;
sig_tbl: node Sig_tbl.t;
combine: (node * node) Vec.t;
pending: node Vec.t; (* refresh signature *)
distinct: node list ref Vec.t; (* disjoint sets *)
true_: node;
false_: node;
}
let create() : t =
{ tbl= T_tbl.create 128;
let create tst : t =
let true_ = T.bool tst true in
let false_ = T.bool tst false in
let self = {
ok=true;
tbl= T_tbl.create 128;
sig_tbl=Sig_tbl.create 128;
combine=Vec.create();
pending=Vec.create();
distinct=Vec.create();
}
true_=Node.make true_;
false_=Node.make false_;
} in
T_tbl.add self.tbl true_ self.true_;
T_tbl.add self.tbl false_ self.false_;
self
let sub_ t k : unit =
match T.view t with
| Bool _ -> ()
| App (_, args) -> args k
match T.cc_view t with
| Bool _ | Opaque _ -> ()
| App_fun (_, args) -> args k
| App_ho (f, args) -> k f; args k
| Eq (a,b) -> k a; k b
| If(a,b,c) -> k a; k b; k c
let rec add_t (self:t) (t:term) : node =
@ -152,8 +191,37 @@ module Make(A: ARG) = struct
if has_dups !r then raise_notrace E_unsat)
self.distinct
let compute_sig (self:t) (n:node) : Signature.t option =
let[@inline] return x = Some x in
match T.cc_view n.n_t with
| Bool _ | Opaque _ -> None
| Eq (a,b) ->
let a = find_t_ self a in
let b = find_t_ self b in
return @@ Eq (a,b)
| App_fun (f, args) ->
let args = args |> Sequence.map (find_t_ self) |> Sequence.to_list in
if args<>[] then (
return @@ App_fun (f, args)
) else None
| App_ho (f, args) ->
let args = args |> Sequence.map (find_t_ self) |> Sequence.to_list in
return @@ App_ho (find_t_ self f, args)
| If (a,b,c) ->
return @@ If(find_t_ self a, find_t_ self b, find_t_ self c)
let update_sig_ (self:t) (n: node) : unit =
let aux s =
match compute_sig self n with
| None -> ()
| Some (Eq (a,b)) ->
if Node.equal a b then (
(* reduce to [true] *)
let n2 = self.true_ in
Log.debugf 5
(fun k->k "(@[minicc.congruence-by-eq@ %a@ %a@])" Node.pp n Node.pp n2);
Vec.push self.combine (n,n2)
)
| Some s ->
Log.debugf 5 (fun k->k "(@[minicc.update-sig@ %a@])" Signature.pp s);
match Sig_tbl.find self.sig_tbl s with
| n2 when Node.equal n n2 -> ()
@ -164,23 +232,28 @@ module Make(A: ARG) = struct
Vec.push self.combine (n,n2)
| exception Not_found ->
Sig_tbl.add self.sig_tbl s n
in
match T.view n.n_t with
| Bool _ -> ()
| App (f, args) ->
let args = args |> Sequence.map (find_t_ self) |> Sequence.to_list in
aux @@ App (f, args)
| If (a,b,c) -> aux @@ If(find_t_ self a, find_t_ self b, find_t_ self c)
let[@inline] is_bool self n = Node.equal self.true_ n || Node.equal self.false_ n
(* merge the two classes *)
let merge_ self (n1,n2) : unit =
let n1 = find_ n1 in
let n2 = find_ n2 in
if not @@ Node.equal n1 n2 then (
(* merge into largest class *)
let n1, n2 = if Node.size n1 > Node.size n2 then n1, n2 else n2, n1 in
(* merge into largest class, or into a boolean *)
let n1, n2 =
if is_bool self n1 then n1, n2
else if is_bool self n2 then n2, n1
else if Node.size n1 > Node.size n2 then n1, n2
else n2, n1 in
Log.debugf 5 (fun k->k "(@[minicc.merge@ :into %a@ %a@])" Node.pp n1 Node.pp n2);
if is_bool self n1 && is_bool self n2 then (
Log.debugf 5 (fun k->k "(minicc.conflict.merge-true-false)");
self.ok <- false;
raise E_unsat
);
List.iter (Vec.push self.pending) n2.n_parents; (* will change signature *)
(* merge parent lists *)
@ -191,9 +264,13 @@ module Make(A: ARG) = struct
Node.iter_cls n2 (fun n -> n.n_root <- n1);
)
let check_ok_ self =
if not self.ok then raise_notrace E_unsat
(* fixpoint of the congruence closure *)
let fixpoint (self:t) : unit =
while not (Vec.is_empty self.pending && Vec.is_empty self.combine) do
check_ok_ self;
while not @@ Vec.is_empty self.pending do
update_sig_ self @@ Vec.pop self.pending
done;
@ -205,10 +282,17 @@ module Make(A: ARG) = struct
(* API *)
let merge (self:t) t1 t2 : unit =
let n1 = add_t self t1 in
let n2 = add_t self t2 in
Vec.push self.combine (n1,n2)
let add_lit (self:t) (p:T.t) (sign:bool) : unit =
match T.cc_view p with
| Eq (t1,t2) when sign ->
let n1 = add_t self t1 in
let n2 = add_t self t2 in
Vec.push self.combine (n1,n2)
| _ ->
(* just merge with true/false *)
let n = add_t self p in
let n2 = if sign then self.true_ else self.false_ in
Vec.push self.combine (n,n2)
let distinct (self:t) l =
begin match l with
@ -220,6 +304,8 @@ module Make(A: ARG) = struct
let check (self:t) : res =
try fixpoint self; Sat
with E_unsat -> Unsat
with E_unsat ->
self.ok <- false;
Unsat
end

36
src/cc/Mini_cc.mli Normal file
View file

@ -0,0 +1,36 @@
(** {1 Mini congruence closure}
This implementation is as simple as possible, and doesn't provide
backtracking, theories, or explanations.
It just decides the satisfiability of a set of (dis)equations.
*)
type res =
| Sat
| Unsat
module type TERM = CC_types.TERM
module type S = sig
type term
type fun_
type term_state
type t
val create : term_state -> t
val add_lit : t -> term -> bool -> unit
(** [add_lit cc p sign] asserts that [p=sign] *)
val distinct : t -> term list -> unit
(** [distinct cc l] asserts that all terms in [l] are distinct *)
val check : t -> res
end
module Make(A: TERM)
: S with type term = A.Term.t
and type fun_ = A.Fun.t
and type term_state = A.Term.state

23
src/cc/Sidekick_cc.ml Normal file
View file

@ -0,0 +1,23 @@
type ('f, 't, 'ts) view = ('f, 't, 'ts) CC_types.view =
| Bool of bool
| App_fun of 'f * 'ts
| App_ho of 't * 'ts
| If of 't * 't * 't
| Eq of 't * 't
| Opaque of 't (* do not enter *)
type payload = Congruence_closure.payload = ..
module CC_types = CC_types
(** Parameter for the congruence closure *)
module type TERM_LIT = CC_types.TERM_LIT
module type FULL = CC_types.FULL
module type S = Congruence_closure.S
module Mini_cc = Mini_cc
module Congruence_closure = Congruence_closure
module Make = Congruence_closure.Make

10
src/cc/dune Normal file
View file

@ -0,0 +1,10 @@
(library
(name Sidekick_cc)
(public_name sidekick.cc)
(libraries containers containers.data msat sequence sidekick.util)
(flags :standard -warn-error -a+8
-color always -safe-string -short-paths -open Sidekick_util)
(ocamlopt_flags :standard -O3 -color always
-unbox-closures -unbox-closures-factor 20))

18
src/smt/CC.ml Normal file
View file

@ -0,0 +1,18 @@
module Arg = struct
module Fun = Cst
module Term = Term
module Lit = Lit
module Value = Value
module Ty = Ty
module Model = Model
module Proof = struct
type t = Solver_types.proof
let pp = Solver_types.pp_proof
let default = Solver_types.Proof_default
end
end
include Sidekick_cc.Make(Arg)
module Mini_cc = Sidekick_cc.Mini_cc.Make(Arg)

13
src/smt/CC.mli Normal file
View file

@ -0,0 +1,13 @@
include Sidekick_cc.S
with type term = Term.t
and type model = Model.t
and type lit = Lit.t
and type fun_ = Cst.t
and type term_state = Term.state
and type proof = Solver_types.proof
module Mini_cc : Sidekick_cc.Mini_cc.S
with type term = Term.t
and type fun_ = Cst.t
and type term_state = Term.state

View file

@ -1,763 +0,0 @@
open Solver_types
module N = Eq_class
type node = N.t
type repr = N.t
type conflict = Theory.conflict
module T_arg = struct
module Fun = Cst
module Term = struct
include Term
let view = cc_view
end
end
module Mini_cc = Mini_cc.Make(T_arg)
(** A signature is a shallow term shape where immediate subterms
are representative *)
module Signature = struct
type t = node Term.view
include Term_cell.Make_eq(N)
end
module Sig_tbl = CCHashtbl.Make(Signature)
type explanation_thunk = explanation lazy_t
type combine_task =
| CT_merge of node * node * explanation_thunk
| CT_distinct of node list * int * explanation
type t = {
tst: Term.state;
tbl: node Term.Tbl.t;
(* internalization [term -> node] *)
signatures_tbl : node Sig_tbl.t;
(* map a signature to the corresponding node in some equivalence class.
A signature is a [term_cell] in which every immediate subterm
that participates in the congruence/evaluation relation
is normalized (i.e. is its own representative).
The critical property is that all members of an equivalence class
that have the same "shape" (including head symbol)
have the same signature *)
pending: node Vec.t;
combine: combine_task Vec.t;
undo: (unit -> unit) Backtrack_stack.t;
on_merge: (repr -> repr -> explanation -> unit) option;
mutable ps_lits: Lit.Set.t; (* TODO: thread it around instead? *)
(* proof state *)
ps_queue: (node*node) Vec.t;
(* pairs to explain *)
true_ : node lazy_t;
false_ : node lazy_t;
}
(* TODO: an additional union-find to keep track, for each term,
of the terms they are known to be equal to, according
to the current explanation. That allows not to prove some equality
several times.
See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *)
let[@inline] is_root_ (n:node) : bool = n.n_root == n
let[@inline] size_ (r:repr) = r.n_size
let[@inline] true_ cc = Lazy.force cc.true_
let[@inline] false_ cc = Lazy.force cc.false_
let[@inline] on_backtrack cc f : unit =
Backtrack_stack.push_if_nonzero_level cc.undo f
(* check if [t] is in the congruence closure.
Invariant: [in_cc t do_cc t => forall u subterm t, in_cc u] *)
let[@inline] mem (cc:t) (t:term): bool = Term.Tbl.mem cc.tbl t
(* find representative, recursively *)
let rec find_rec cc (n:node) : repr =
if n==n.n_root then (
n
) else (
(* TODO: path compression, assuming backtracking restores equiv classes
properly *)
let root = find_rec cc n.n_root in
root
)
(* traverse the equivalence class of [n] *)
let iter_class_ (n:node) : node Sequence.t =
fun yield ->
let rec aux u =
yield u;
if u.n_next != n then aux u.n_next
in
aux n
(* get term that should be there *)
let[@inline] get_ cc (t:term) : node =
try Term.Tbl.find cc.tbl t
with Not_found ->
Log.debugf 1 (fun k->k "(@[<hv1>cc.error@ :missing-term %a@])" Term.pp t);
assert false
(* non-recursive, inlinable function for [find] *)
let[@inline] find st (n:node) : repr =
if n == n.n_root then n else find_rec st n
let[@inline] find_tn cc (t:term) : repr = get_ cc t |> find cc
let[@inline] same_class cc (n1:node)(n2:node): bool =
N.equal (find cc n1) (find cc n2)
(* print full state *)
let pp_full out (cc:t) : unit =
let pp_next out n =
Fmt.fprintf out "@ :next %a" N.pp n.n_next in
let pp_root out n =
if is_root_ n then Fmt.string out " :is-root" else Fmt.fprintf out "@ :root %a" N.pp n.n_root in
let pp_expl out n = match n.n_expl with
| E_none -> ()
| E_some e ->
Fmt.fprintf out " (@[:forest %a :expl %a@])" N.pp e.next Explanation.pp e.expl
in
let pp_n out n =
Fmt.fprintf out "(@[%a%a%a%a@])" Term.pp n.n_term pp_root n pp_next n pp_expl n
and pp_sig_e out (s,n) =
Fmt.fprintf out "(@[<1>%a@ <--> %a%a@])" Signature.pp s N.pp n pp_root n
in
Fmt.fprintf out
"(@[@{<yellow>cc.state@}@ (@[<hv>:nodes@ %a@])@ (@[<hv>:sig@ %a@])@])"
(Util.pp_seq ~sep:" " pp_n) (Term.Tbl.values cc.tbl)
(Util.pp_seq ~sep:" " pp_sig_e) (Sig_tbl.to_seq cc.signatures_tbl)
(* compute signature *)
let signature cc (t:term): Signature.t option =
let find = find_tn cc in
begin match Term.view t with
| App_cst (_, a) when IArray.is_empty a -> None
| App_cst (c, _) when not @@ Cst.do_cc c -> None (* no CC *)
| App_cst (f, a) -> Some (App_cst (f, IArray.map find a)) (* FIXME: relevance? *)
| Bool _ | If _
-> None (* no congruence for these *)
end
(* find whether the given (parent) term corresponds to some signature
in [signatures_] *)
let find_by_signature cc (t:term) : repr option =
match signature cc t with
| None -> None
| Some s -> Sig_tbl.get cc.signatures_tbl s
let add_signature cc (n:node): unit =
match signature cc n.n_term with
| None -> ()
| Some s ->
(* add, but only if not present already *)
begin match Sig_tbl.find cc.signatures_tbl s with
| exception Not_found ->
Log.debugf 15
(fun k->k "(@[cc.add_sig@ %a@ <--> %a@])" Signature.pp s N.pp n);
on_backtrack cc (fun () -> Sig_tbl.remove cc.signatures_tbl s);
Sig_tbl.add cc.signatures_tbl s n;
| r' ->
assert (same_class cc n r');
end
let push_pending cc t : unit =
if not @@ N.get_field N.field_is_pending t then (
Log.debugf 5 (fun k->k "(@[<hv1>cc.push_pending@ %a@])" N.pp t);
N.set_field N.field_is_pending true t;
Vec.push cc.pending t
)
let push_combine cc t u e : unit =
Log.debugf 5
(fun k->k "(@[<hv1>cc.push_combine@ :t1 %a@ :t2 %a@ :expl %a@])"
N.pp t N.pp u Explanation.pp (Lazy.force e));
Vec.push cc.combine @@ CT_merge (t,u,e)
(* re-root the explanation tree of the equivalence class of [n]
so that it points to [n].
postcondition: [n.n_expl = None] *)
let rec reroot_expl (cc:t) (n:node): unit =
let old_expl = n.n_expl in
begin match old_expl with
| E_none -> () (* already root *)
| E_some {next=u; expl=e_n_u} ->
reroot_expl cc u;
u.n_expl <- E_some {next=n; expl=e_n_u};
n.n_expl <- E_none;
end
let raise_conflict (cc:t) (acts:sat_actions) (e:conflict): _ =
(* clear tasks queue *)
Vec.iter (N.set_field N.field_is_pending false) cc.pending;
Vec.clear cc.pending;
Vec.clear cc.combine;
let c = List.map Lit.neg e in
acts.Msat.acts_raise_conflict c Proof_default
let[@inline] all_classes cc : repr Sequence.t =
Term.Tbl.values cc.tbl
|> Sequence.filter is_root_
(* TODO: use markers and lockstep iteration instead *)
(* distance from [t] to its root in the proof forest *)
let[@inline][@unroll 2] rec distance_to_root (n:node): int = match n.n_expl with
| E_none -> 0
| E_some {next=t'; _} -> 1 + distance_to_root t'
(* TODO: bool flag on nodes + stepwise progress + cleanup *)
(* find the closest common ancestor of [a] and [b] in the proof forest *)
let find_common_ancestor (a:node) (b:node) : node =
let d_a = distance_to_root a in
let d_b = distance_to_root b in
(* drop [n] nodes in the path from [t] to its root *)
let rec drop_ n t =
if n=0 then t
else match t.n_expl with
| E_none -> assert false
| E_some {next=t'; _} -> drop_ (n-1) t'
in
(* reduce to the problem where [a] and [b] have the same distance to root *)
let a, b =
if d_a > d_b then drop_ (d_a-d_b) a, b
else if d_a < d_b then a, drop_ (d_b-d_a) b
else a, b
in
(* traverse stepwise until a==b *)
let rec aux_same_dist a b =
if a==b then a
else match a.n_expl, b.n_expl with
| E_none, _ | _, E_none -> assert false
| E_some {next=a'; _}, E_some {next=b'; _} -> aux_same_dist a' b'
in
aux_same_dist a b
let[@inline] ps_add_obligation (cc:t) a b = Vec.push cc.ps_queue (a,b)
let[@inline] ps_add_lit ps l = ps.ps_lits <- Lit.Set.add l ps.ps_lits
let ps_clear (cc:t) =
cc.ps_lits <- Lit.Set.empty;
Vec.clear cc.ps_queue;
()
let decompose_explain cc (e:explanation): unit =
Log.debugf 5 (fun k->k "(@[cc.decompose_expl@ %a@])" Explanation.pp e);
begin match e with
| E_reduction -> ()
| E_lit lit -> ps_add_lit cc lit
| E_lits l -> List.iter (ps_add_lit cc) l
| E_merges l ->
(* need to explain each merge in [l] *)
IArray.iter (fun (t,u) -> ps_add_obligation cc t u) l
end
(* explain why [a = parent_a], where [a -> ... -> parent_a] in the
proof forest *)
let rec explain_along_path ps (a:node) (parent_a:node) : unit =
if a!=parent_a then (
match a.n_expl with
| E_none -> assert false
| E_some {next=next_a; expl=e_a_b} ->
decompose_explain ps e_a_b;
(* now prove [next_a = parent_a] *)
explain_along_path ps next_a parent_a
)
(* find explanation *)
let explain_loop (cc : t) : Lit.Set.t =
while not (Vec.is_empty cc.ps_queue) do
let a, b = Vec.pop cc.ps_queue in
Log.debugf 5
(fun k->k "(@[cc.explain_loop.at@ %a@ =?= %a@])" N.pp a N.pp b);
assert (N.equal (find cc a) (find cc b));
let c = find_common_ancestor a b in
explain_along_path cc a c;
explain_along_path cc b c;
done;
cc.ps_lits
(* TODO: do not use ps_lits anymore *)
let explain_eq_n ?(init=Lit.Set.empty) cc (n1:node) (n2:node) : Lit.Set.t =
ps_clear cc;
cc.ps_lits <- init;
ps_add_obligation cc n1 n2;
explain_loop cc
let explain_unfold ?(init=Lit.Set.empty) cc (e:explanation) : Lit.Set.t =
ps_clear cc;
cc.ps_lits <- init;
decompose_explain cc e;
explain_loop cc
(* add [tag] to [n], indicating that [n] is distinct from all the other
nodes tagged with [tag]
precond: [n] is a representative *)
let add_tag_n cc (n:node) (tag:int) (expl:explanation) : unit =
assert (is_root_ n);
if not (Util.Int_map.mem tag n.n_tags) then (
on_backtrack cc
(fun () -> n.n_tags <- Util.Int_map.remove tag n.n_tags);
n.n_tags <- Util.Int_map.add tag (n,expl) n.n_tags;
)
(* TODO: payload for set of tags *)
(* TODO: payload for mapping an equiv class to a set of literals, for bool prop *)
let relevant_subterms (t:Term.t) : Term.t Sequence.t =
fun yield ->
match t.term_view with
| App_cst (c, a) when Cst.do_cc c -> IArray.iter yield a
| Bool _ | App_cst _ -> ()
| If (a,b,c) ->
(* TODO: relevancy? only [a] needs be decided for now *)
yield a;
yield b;
yield c
(* Checks if [ra] and [~into] have compatible normal forms and can
be merged w.r.t. the theories.
Side effect: also pushes sub-tasks *)
let notify_merge cc (ra:repr) ~into:(rb:repr) (e:explanation): unit =
assert (is_root_ rb);
match cc.on_merge with
| Some f -> f ra rb e
| None -> ()
(* main CC algo: add terms from [pending] to the signature table,
check for collisions *)
let rec update_tasks (cc:t) (acts:sat_actions) : unit =
while not (Vec.is_empty cc.pending && Vec.is_empty cc.combine) do
Vec.iter (task_pending_ cc) cc.pending;
Vec.clear cc.pending;
Vec.iter (task_combine_ cc acts) cc.combine;
Vec.clear cc.combine;
done
and task_pending_ cc n =
N.set_field N.field_is_pending false n;
(* check if some parent collided *)
begin match find_by_signature cc n.n_term with
| None ->
(* add to the signature table [sig(n) --> n] *)
add_signature cc n
| Some u when n == u -> ()
| Some u ->
(* [t1] and [t2] must be applications of the same symbol to
arguments that are pairwise equal *)
assert (n != u);
let expl = lazy (
match n.n_term.term_view, u.n_term.term_view with
| App_cst (f1, a1), App_cst (f2, a2) ->
assert (Cst.equal f1 f2);
assert (IArray.length a1 = IArray.length a2);
(* TODO: just use "congruence" as explanation *)
Explanation.mk_merges @@ IArray.map2 (fun u1 u2 -> add_term_rec_ cc u1, add_term_rec_ cc u2) a1 a2
| If _, _ | App_cst _, _ | Bool _, _
-> assert false
) in
push_combine cc n u expl
end;
(* TODO: evaluate [(= t u) := true] when [find t==find u] *)
(* FIXME: when to actually evaluate?
eval_pending cc;
*)
()
and[@inline] task_combine_ cc acts = function
| CT_merge (a,b,e_ab) -> task_merge_ cc acts a b e_ab
| CT_distinct (l,tag,e) -> task_distinct_ cc acts l tag e
(* main CC algo: merge equivalence classes in [st.combine].
@raise Exn_unsat if merge fails *)
and task_merge_ cc acts a b e_ab : unit =
let ra = find cc a in
let rb = find cc b in
if not @@ N.equal ra rb then (
assert (is_root_ ra);
assert (is_root_ rb);
let lazy e_ab = e_ab in
(* We will merge [r_from] into [r_into].
we try to ensure that [size ra <= size rb] in general, but always
keep values as representative *)
let r_from, r_into =
if Term.is_value ra.n_term then rb, ra
else if Term.is_value rb.n_term then ra, rb
else if size_ ra > size_ rb then rb, ra
else ra, rb
in
(* check we're not merging [true] and [false] *)
if (N.equal ra (true_ cc) && N.equal rb (false_ cc)) ||
(N.equal rb (true_ cc) && N.equal ra (false_ cc)) then (
Log.debugf 5
(fun k->k "(@[<hv>cc.merge.true_false_conflict@ @[:r1 %a@]@ @[:r2 %a@]@ :e_ab %a@])"
N.pp ra N.pp rb Explanation.pp e_ab);
let lits = explain_unfold cc e_ab in
let lits = explain_eq_n ~init:lits cc a ra in
let lits = explain_eq_n ~init:lits cc b rb in
raise_conflict cc acts @@ Lit.Set.elements lits
);
(* TODO: isntead call micro theories, including "distinct" *)
(* update set of tags the new node cannot be equal to *)
let new_tags =
Util.Int_map.union
(fun _i (n1,e1) (n2,e2) ->
(* both maps contain same tag [_i]. conflict clause:
[e1 & e2 & e_ab] impossible *)
Log.debugf 5
(fun k->k "(@[<hv>cc.merge.distinct_conflict@ :tag %d@ \
@[:r1 %a@ :e1 %a@]@ @[:r2 %a@ :e2 %a@]@ :e_ab %a@])"
_i N.pp n1 Explanation.pp e1
N.pp n2 Explanation.pp e2 Explanation.pp e_ab);
let lits = explain_unfold cc e1 in
let lits = explain_unfold ~init:lits cc e2 in
let lits = explain_unfold ~init:lits cc e_ab in
let lits = explain_eq_n ~init:lits cc a n1 in
let lits = explain_eq_n ~init:lits cc b n2 in
raise_conflict cc acts @@ Lit.Set.elements lits)
ra.n_tags rb.n_tags
in
(* when merging terms with [true] or [false], possibly propagate them to SAT *)
let merge_bool r1 t1 r2 t2 =
if N.equal r1 (true_ cc) then (
propagate_bools cc acts r2 t2 r1 t1 e_ab true
) else if N.equal r1 (false_ cc) then (
propagate_bools cc acts r2 t2 r1 t1 e_ab false
)
in
merge_bool ra a rb b;
merge_bool rb b ra a;
(* perform [union r_from r_into] *)
Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into);
(* TODO: only iterate on parents of [rb] *)
(* TODO: [ra.parents <- ra.parent ++ rb.parents] *)
begin
(* for each node in [r_from]'s class:
- make it point to [r_into]
- push it into [st.pending] *)
iter_class_ r_from
(fun u ->
assert (u.n_root == r_from);
on_backtrack cc (fun () -> u.n_root <- r_from);
u.n_root <- r_into;
Bag.to_seq u.n_parents
(fun parent -> push_pending cc parent));
(* now merge the classes *)
let r_into_old_tags = r_into.n_tags in
let r_into_old_next = r_into.n_next in
let r_from_old_next = r_from.n_next in
on_backtrack cc
(fun () ->
Log.debugf 15
(fun k->k "(@[cc.undo_merge@ :from %a :into %a@])"
Term.pp r_from.n_term Term.pp r_into.n_term);
r_into.n_next <- r_into_old_next;
r_from.n_next <- r_from_old_next;
r_into.n_tags <- r_into_old_tags);
r_into.n_tags <- new_tags;
(* swap [into.next] and [from.next], merging the classes *)
r_into.n_next <- r_from_old_next;
r_from.n_next <- r_into_old_next;
end;
(* update explanations (a -> b), arbitrarily.
Note that here we merge the classes by adding a bridge between [a]
and [b], not their roots. *)
begin
reroot_expl cc a;
assert (a.n_expl = E_none);
(* on backtracking, link may be inverted, but we delete the one
that bridges between [a] and [b] *)
on_backtrack cc
(fun () -> match a.n_expl, b.n_expl with
| E_some e, _ when N.equal e.next b -> a.n_expl <- E_none
| _, E_some e when N.equal e.next a -> b.n_expl <- E_none
| _ -> assert false);
a.n_expl <- E_some {next=b; expl=e_ab};
end;
(* notify listeners of the merge *)
notify_merge cc r_from ~into:r_into e_ab;
)
and task_distinct_ cc acts (l:node list) tag expl : unit =
let l = List.map (fun n -> n, find cc n) l in
let coll =
Sequence.diagonal_l l
|> Sequence.find_pred (fun ((_,r1),(_,r2)) -> N.equal r1 r2)
in
begin match coll with
| Some ((n1,_r1),(n2,_r2)) ->
(* two classes are already equal *)
Log.debugf 5
(fun k->k "(@[cc.distinct.conflict@ %a = %a@ :expl %a@])" N.pp n1 N.pp
n2 Explanation.pp expl);
let lits = explain_unfold cc expl in
raise_conflict cc acts (Lit.Set.to_list lits)
| None ->
(* put a tag on all equivalence classes, that will make their merge fail *)
List.iter (fun (_,n) -> add_tag_n cc n tag expl) l
end
(* we are merging [r1] with [r2==Bool(sign)], so propagate each term [u1]
in the equiv class of [r1] that is a known literal back to the SAT solver
and which is not the one initially merged.
We can explain the propagation with [u1 = t1 =e= t2 = r2==bool] *)
and propagate_bools cc acts r1 t1 r2 t2 (e_12:explanation) sign : unit =
(* explanation for [t1 =e= t2 = r2] *)
let half_expl = lazy (
let expl = explain_unfold cc e_12 in
explain_eq_n ~init:expl cc r2 t2
) in
iter_class_ r1
(fun u1 ->
(* propagate if:
- [u1] is a proper literal
- [t2 != r2], because that can only happen
after an explicit merge (no way to obtain that by propagation)
*)
if N.get_field N.field_is_literal u1 && not (N.equal r2 t2) then (
let lit = Lit.atom ~sign u1.n_term in
Log.debugf 5 (fun k->k "(@[cc.bool_propagate@ %a@])" Lit.pp lit);
(* complete explanation with the [u1=t1] chunk *)
let expl = explain_eq_n ~init:(Lazy.force half_expl) cc u1 t1 in
let reason = Msat.Consequence (Lit.Set.to_list expl, Proof_default) in
acts.Msat.acts_propagate lit reason
))
(* add [t] to [cc] when not present already *)
and add_new_term_ cc (t:term) : node =
assert (not @@ mem cc t);
Log.debugf 15 (fun k->k "(@[cc.add-term@ %a@])" Term.pp t);
let n = N.make t in
(* how to add a subterm *)
let add_to_parents_of_sub_node (sub:node) : unit =
let sub = find cc sub in (* update the repr! *)
let old_parents = sub.n_parents in
on_backtrack cc (fun () -> sub.n_parents <- old_parents);
sub.n_parents <- Bag.cons n sub.n_parents;
in
(* add sub-term to [cc], and register [n] to its parents *)
let add_sub_t (u:term) : unit =
let n_u = add_term_rec_ cc u in
add_to_parents_of_sub_node n_u
in
(* register sub-terms, add [t] to their parent list *)
relevant_subterms t add_sub_t;
(* remove term when we backtrack *)
on_backtrack cc
(fun () ->
Log.debugf 15 (fun k->k "(@[cc.remove-term@ %a@])" Term.pp t);
Term.Tbl.remove cc.tbl t);
(* add term to the table *)
Term.Tbl.add cc.tbl t n;
(* [n] might be merged with other equiv classes *)
push_pending cc n;
n
(* add a term *)
and[@inline] add_term_rec_ cc t : node =
try Term.Tbl.find cc.tbl t
with Not_found -> add_new_term_ cc t
let check_invariants_ (cc:t) =
Log.debug 5 "(cc.check-invariants)";
Log.debugf 15 (fun k-> k "%a" pp_full cc);
assert (Term.equal (Term.true_ cc.tst) (true_ cc).n_term);
assert (Term.equal (Term.false_ cc.tst) (false_ cc).n_term);
assert (not @@ same_class cc (true_ cc) (false_ cc));
assert (Vec.is_empty cc.combine);
assert (Vec.is_empty cc.pending);
(* check that subterms are internalized *)
Term.Tbl.iter
(fun t n ->
assert (Term.equal t n.n_term);
assert (not @@ N.get_field N.field_is_pending n);
relevant_subterms t
(fun u -> assert (Term.Tbl.mem cc.tbl u));
assert (N.equal n.n_root n.n_next.n_root);
(* check proper signature.
note that some signatures in the sig table can be obsolete (they
were not removed) but there must be a valid, up-to-date signature for
each term *)
begin match signature cc t with
| None -> ()
| Some s ->
Log.debugf 15 (fun k->k "(@[cc.check-sig@ %a@ :sig %a@])" Term.pp t Signature.pp s);
(* add, but only if not present already *)
begin match Sig_tbl.find cc.signatures_tbl s with
| exception Not_found -> assert false
| repr_s -> assert (same_class cc n repr_s)
end
end;
)
cc.tbl;
()
let[@inline] check_invariants (cc:t) : unit =
if Util._CHECK_INVARIANTS then check_invariants_ cc
let[@inline] add cc t : node = add_term_rec_ cc t
let add_seq cc seq =
seq (fun t -> ignore @@ add_term_rec_ cc t);
()
let[@inline] push_level (self:t) : unit =
Backtrack_stack.push_level self.undo
let pop_levels (self:t) n : unit =
Vec.iter (N.set_field N.field_is_pending false) self.pending;
Vec.clear self.pending;
Vec.clear self.combine;
Log.debugf 15
(fun k->k "(@[cc.pop-levels %d@ :n-lvls %d@])" n (Backtrack_stack.n_levels self.undo));
Backtrack_stack.pop_levels self.undo n ~f:(fun f -> f());
()
(* TODO: if a lit is [= a b], merge [a] and [b];
if it's [distinct a1an], make them distinct, etc. etc. *)
(* assert that this boolean literal holds *)
let assert_lit cc lit : unit =
let t = Lit.view lit in
assert (Ty.is_prop t.term_ty);
Log.debugf 5 (fun k->k "(@[cc.assert_lit@ %a@])" Lit.pp lit);
let sign = Lit.sign lit in
(* equate t and true/false *)
let rhs = if sign then true_ cc else false_ cc in
let n = add_term_rec_ cc t in
(* TODO: ensure that this is O(1).
basically, just have [n] point to true/false and thus acquire
the corresponding value, so its superterms (like [ite]) can evaluate
properly *)
push_combine cc n rhs (Lazy.from_val @@ E_lit lit)
let[@inline] assert_lits cc lits : unit =
Sequence.iter (assert_lit cc) lits
let assert_eq cc (t:term) (u:term) e : unit =
let n1 = add_term_rec_ cc t in
let n2 = add_term_rec_ cc u in
if not (same_class cc n1 n2) then (
let e = Lazy.from_val @@ Explanation.E_lits e in
push_combine cc n1 n2 e;
)
let assert_distinct cc (l:term list) ~neq (lit:Lit.t) : unit =
assert (match l with[] | [_] -> false | _ -> true);
let tag = Term.id neq in
Log.debugf 5
(fun k->k "(@[cc.assert_distinct@ (@[%a@])@ :tag %d@])" (Util.pp_list Term.pp) l tag);
let l = List.map (add cc) l in
Vec.push cc.combine @@ CT_distinct (l, tag, Explanation.lit lit)
let create ?on_merge ?(size=`Big) (tst:Term.state) : t =
let size = match size with `Small -> 128 | `Big -> 2048 in
let rec cc = {
tst;
tbl = Term.Tbl.create size;
signatures_tbl = Sig_tbl.create size;
on_merge;
pending=Vec.create();
combine=Vec.create();
ps_lits=Lit.Set.empty;
undo=Backtrack_stack.create();
ps_queue=Vec.create();
true_;
false_;
} and true_ = lazy (
add_term_rec_ cc (Term.true_ tst)
) and false_ = lazy (
add_term_rec_ cc (Term.false_ tst)
)
in
ignore (Lazy.force true_ : node);
ignore (Lazy.force false_ : node);
cc
let[@inline] find_t cc t : repr =
let n = Term.Tbl.find cc.tbl t in
find cc n
let[@inline] check cc acts : unit =
Log.debug 5 "(cc.check)";
update_tasks cc acts
(* model: map each uninterpreted equiv class to some ID *)
let mk_model (cc:t) (m:Model.t) : Model.t =
Log.debugf 15 (fun k->k "(@[cc.mk_model@ %a@])" pp_full cc);
(* populate [repr -> value] table *)
let t_tbl = N.Tbl.create 32 in
(* type -> default value *)
let ty_tbl = Ty.Tbl.create 8 in
Term.Tbl.values cc.tbl
(fun r ->
if is_root_ r then (
let t = r.n_term in
let v = match Model.eval m t with
| Some v -> v
| None ->
if same_class cc r (true_ cc) then Value.true_
else if same_class cc r (false_ cc) then Value.false_
else (
Value.mk_elt
(ID.makef "v_%d" @@ Term.id t)
(Term.ty r.n_term)
)
in
if not @@ Ty.Tbl.mem ty_tbl (Term.ty t) then (
Ty.Tbl.add ty_tbl (Term.ty t) v; (* also give a value to this type *)
);
N.Tbl.add t_tbl r v
));
(* now map every uninterpreted term to its representative's value, and
create function tables *)
let m, funs =
Term.Tbl.to_seq cc.tbl
|> Sequence.fold
(fun (m,funs) (t,r) ->
let r = find cc r in (* get representative *)
match Term.view t with
| _ when Model.mem t m -> m, funs
| App_cst (c, args) ->
if Model.mem t m then m, funs
else if Cst.is_undefined c && IArray.length args > 0 then (
(* update signature of [c] *)
let ty = Term.ty t in
let v = N.Tbl.find t_tbl r in
let args =
args
|> IArray.map (fun t -> N.Tbl.find t_tbl @@ find_tn cc t)
|> IArray.to_list
in
let ty, l = Cst.Map.get_or c funs ~default:(ty,[]) in
m, Cst.Map.add c (ty, (args,v)::l) funs
) else (
let v = N.Tbl.find t_tbl r in
Model.add t v m, funs
)
| _ ->
let v = N.Tbl.find t_tbl r in
Model.add t v m, funs)
(m,Cst.Map.empty)
in
(* get or make a default value for this type *)
let rec get_ty_default (ty:Ty.t) : Value.t =
match Ty.view ty with
| Ty_prop -> Value.true_
| Ty_atomic { def = Ty_uninterpreted _;_} ->
(* domain element *)
Ty.Tbl.get_or_add ty_tbl ~k:ty
~f:(fun ty -> Value.mk_elt (ID.makef "ty_%d" @@ Ty.id ty) ty)
| Ty_atomic { def = Ty_def d; args; _} ->
(* ask the theory for a default value *)
Ty.Tbl.get_or_add ty_tbl ~k:ty
~f:(fun _ty ->
let vals = List.map get_ty_default args in
d.default_val vals)
in
let funs =
Cst.Map.map
(fun (ty,l) ->
Model.Fun_interpretation.make ~default:(get_ty_default ty) l)
funs
in
Model.add_funs funs m

View file

@ -1,73 +0,0 @@
(** {2 Congruence Closure} *)
open Solver_types
type t
(** Global state of the congruence closure *)
type node = Eq_class.t
(** Node in the congruence closure *)
type repr = Eq_class.t
(** Node that is currently a representative *)
type conflict = Theory.conflict
val create :
?on_merge:(repr -> repr -> explanation -> unit) ->
?size:[`Small | `Big] ->
Term.state ->
t
(** Create a new congruence closure.
@param acts the actions available to the congruence closure
*)
val find : t -> node -> repr
(** Current representative *)
val add : t -> term -> node
(** Add the term to the congruence closure, if not present already.
Will be backtracked. *)
val find_t : t -> term -> repr
(** Current representative of the term.
@raise Not_found if the term is not already {!add}-ed. *)
val add_seq : t -> term Sequence.t -> unit
(** Add a sequence of terms to the congruence closure *)
val all_classes : t -> repr Sequence.t
(** All current classes *)
val assert_lit : t -> Lit.t -> unit
(** Given a literal, assume it in the congruence closure and propagate
its consequences. Will be backtracked. *)
val assert_lits : t -> Lit.t Sequence.t -> unit
val assert_eq : t -> term -> term -> Lit.t list -> unit
val assert_distinct : t -> term list -> neq:term -> Lit.t -> unit
(** [assert_distinct l ~expl:u e] asserts all elements of [l] are distinct
with explanation [e]
precond: [u = distinct l] *)
val check : t -> sat_actions -> unit
(** Perform all pending operations done via {!assert_eq}, {!assert_lit}, etc.
Will use the [sat_actions] to propagate literals, declare conflicts, etc. *)
val push_level : t -> unit
val pop_levels : t -> int -> unit
val mk_model : t -> Model.t -> Model.t
(** Enrich a model by mapping terms to their representative's value,
if any. Otherwise map the representative to a fresh value *)
(**/**)
val check_invariants : t -> unit
val pp_full : t Fmt.printer
(**/**)
module T_arg : Mini_cc_intf.ARG with type Fun.t = cst and type Term.t = Term.t
module Mini_cc : module type of Mini_cc.Make(T_arg)

View file

@ -1,66 +0,0 @@
open Solver_types
type t = equiv_class
type payload = equiv_class_payload = ..
let field_is_active = Node_bits.mk_field()
let field_is_pending = Node_bits.mk_field()
let field_is_literal = Node_bits.mk_field()
let () = Node_bits.freeze()
let[@inline] equal (n1:t) n2 = n1==n2
let[@inline] hash n = Term.hash n.n_term
let[@inline] term n = n.n_term
let[@inline] payload n = n.n_payload
let[@inline] pp out n = Term.pp out n.n_term
let make (t:term) : t =
let rec n = {
n_term=t;
n_bits=Node_bits.empty;
n_parents=Bag.empty;
n_root=n;
n_expl=E_none;
n_payload=[];
n_next=n;
n_size=1;
n_tags=Util.Int_map.empty;
} in
n
let set_payload ?(can_erase=fun _->false) n e =
let rec aux = function
| [] -> [e]
| e' :: tail when can_erase e' -> e :: tail
| e' :: tail -> e' :: aux tail
in
n.n_payload <- aux n.n_payload
let payload_find ~f:p n =
let[@unroll 2] rec aux = function
| [] -> None
| e1 :: tail ->
match p e1 with
| Some _ as res -> res
| None -> aux tail
in
aux n.n_payload
let payload_pred ~f:p n =
begin match n.n_payload with
| [] -> false
| e :: _ when p e -> true
| _ :: e :: _ when p e -> true
| _ :: _ :: e :: _ when p e -> true
| l -> List.exists p l
end
let[@inline] get_field f t = Node_bits.get f t.n_bits
let[@inline] set_field f b t = t.n_bits <- Node_bits.set f b t.n_bits
module Tbl = CCHashtbl.Make(struct
type t = equiv_class
let equal = equal
let hash = hash
end)

View file

@ -1,61 +0,0 @@
open Solver_types
(** {1 Equivalence Classes} *)
(** An equivalence class is a set of terms that are currently equal
in the partial model built by the solver.
The class is represented by a collection of nodes, one of which is
distinguished and is called the "representative".
All information pertaining to the whole equivalence class is stored
in this representative's node.
When two classes become equal (are "merged"), one of the two
representatives is picked as the representative of the new class.
The new class contains the union of the two old classes' nodes.
We also allow theories to store additional information in the
representative. This information can be used when two classes are
merged, to detect conflicts and solve equations à la Shostak.
*)
type t = equiv_class
type payload = equiv_class_payload = ..
val field_is_active : Node_bits.field
(** The term is needed for evaluation. We must try to evaluate it
or to find a value for it using the theory *)
val field_is_pending : Node_bits.field
(** true iff the node is in the [cc.pending] queue *)
val field_is_literal : Node_bits.field
(** This term is a boolean literal, subject to propagations *)
(** {2 basics} *)
val term : t -> term
val equal : t -> t -> bool
val hash : t -> int
val pp : t Fmt.printer
val payload : t -> payload list
(** {2 Helpers} *)
val make : term -> t
(** Make a new equivalence class whose representative is the given term *)
val payload_find: f:(payload -> 'a option) -> t -> 'a option
val payload_pred: f:(payload -> bool) -> t -> bool
val set_payload : ?can_erase:(payload -> bool) -> t -> payload -> unit
(** Add given payload
@param can_erase if provided, checks whether an existing value
is to be replaced instead of adding a new entry *)
val get_field : Node_bits.field -> t -> bool
val set_field : Node_bits.field -> bool -> t -> unit
module Tbl : CCHashtbl.S with type key = t

View file

@ -1,26 +0,0 @@
open Solver_types
type t = explanation =
| E_reduction (* by pure reduction, tautologically equal *)
| E_merges of (equiv_class * equiv_class) IArray.t (* caused by these merges *)
| E_lit of lit (* because of this literal *)
| E_lits of lit list (* because of this (true) conjunction *)
let compare = cmp_exp
let equal a b = cmp_exp a b = 0
let pp = pp_explanation
let mk_merges l : t = E_merges l
let mk_lit l : t = E_lit l
let mk_lits = function [x] -> mk_lit x | l -> E_lits l
let mk_reduction : t = E_reduction
let[@inline] lit l : t = E_lit l
module Set = CCSet.Make(struct
type t = explanation
let compare = compare
end)

View file

@ -8,7 +8,7 @@ type t = lit = {
let[@inline] neg l = {l with lit_sign=not l.lit_sign}
let[@inline] sign t = t.lit_sign
let[@inline] view (t:t): term = t.lit_term
let[@inline] term (t:t): term = t.lit_term
let[@inline] abs t: t = {t with lit_sign=true}

View file

@ -10,7 +10,7 @@ type t = lit = {
val neg : t -> t
val abs : t -> t
val sign : t -> bool
val view : t -> term
val term : t -> term
val as_atom : t -> term * bool
val atom : ?sign:bool -> term -> t
val hash : t -> int

View file

@ -1,18 +0,0 @@
(** {1 Mini congruence closure} *)
type ('f, 't, 'ts) view = ('f, 't, 'ts) Mini_cc_intf.view =
| Bool of bool
| App of 'f * 'ts
| If of 't * 't * 't
type res = Mini_cc_intf.res =
| Sat
| Unsat
module type ARG = Mini_cc_intf.ARG
module type S = Mini_cc_intf.S
module Make(A: ARG)
: S with type term = A.Term.t
and type fun_ = A.Fun.t

View file

@ -1,47 +0,0 @@
type ('f, 't, 'ts) view =
| Bool of bool
| App of 'f * 'ts
| If of 't * 't * 't
(* TODO: also HO app, Eq, Distinct cases?
-> then API that just adds boolean terms and does the right thing in case of
Eq/Distinct *)
type res =
| Sat
| Unsat
module type ARG = sig
module Fun : sig
type t
val equal : t -> t -> bool
val hash : t -> int
val pp : t Fmt.printer
end
module Term : sig
type t
val equal : t -> t -> bool
val hash : t -> int
val pp : t Fmt.printer
(** View the term through the lens of the congruence closure *)
val view : t -> (Fun.t, t, t Sequence.t) view
end
end
module type S = sig
type term
type fun_
type t
val create : unit -> t
val merge : t -> term -> term -> unit
val distinct : t -> term list -> unit
val check : t -> res
end

View file

@ -58,6 +58,24 @@ let empty : t = {
funs=Cst.Map.empty;
}
(* FIXME: ues this to allocate a default value for each sort
(* get or make a default value for this type *)
let rec get_ty_default (ty:Ty.t) : Value.t =
match Ty.view ty with
| Ty_prop -> Value.true_
| Ty_atomic { def = Ty_uninterpreted _;_} ->
(* domain element *)
Ty_tbl.get_or_add ty_tbl ~k:ty
~f:(fun ty -> Value.mk_elt (ID.makef "ty_%d" @@ Ty.id ty) ty)
| Ty_atomic { def = Ty_def d; args; _} ->
(* ask the theory for a default value *)
Ty_tbl.get_or_add ty_tbl ~k:ty
~f:(fun _ty ->
let vals = List.map get_ty_default args in
d.default_val vals)
in
*)
let[@inline] mem t m = Term.Map.mem t m.values
let[@inline] find t m = Term.Map.get t m.values
@ -102,9 +120,9 @@ let add_funs fs m : t = merge {values=Term.Map.empty; funs=fs} m
let pp out {values; funs} =
let module FI = Fun_interpretation in
let pp_tv out (t,v) = Fmt.fprintf out "(@[%a@ %a@])" Term.pp t Value.pp v in
let pp_tv out (t,v) = Fmt.fprintf out "(@[%a@ := %a@])" Term.pp t Value.pp v in
let pp_fun_entry out (vals,ret) =
Format.fprintf out "(@[%a@ %a@])" (Fmt.Dump.list Value.pp) vals Value.pp ret
Format.fprintf out "(@[%a@ := %a@])" (Fmt.Dump.list Value.pp) vals Value.pp ret
in
let pp_fun out (c, fi: Cst.t * FI.t) =
Format.fprintf out "(@[<hov>%a :default %a@ %a@])"
@ -127,6 +145,10 @@ let eval (m:t) (t:Term.t) : Value.t option =
| V_bool false -> aux c
| v -> Error.errorf "@[Model: wrong value@ for boolean %a@ %a@]" Term.pp a Value.pp v
end
| Eq(a,b) ->
let a = aux a in
let b = aux b in
if Value.equal a b then Value.true_ else Value.false_
| App_cst (c, args) ->
begin try Term.Map.find t m.values
with Not_found ->

View file

@ -37,10 +37,6 @@ val empty : t
val add : Term.t -> Value.t -> t -> t
val add_fun : Cst.t -> Fun_interpretation.t -> t -> t
val add_funs : Fun_interpretation.t Cst.Map.t -> t -> t
val mem : Term.t -> t -> bool
val find : Term.t -> t -> Value.t option

View file

@ -20,7 +20,6 @@ module Solver = Solver
module Solver_types = Solver_types
(**/**)
module Bag = Bag
module Vec = Msat.Vec
module Log = Msat.Log
(**/**)

View file

@ -208,11 +208,9 @@ let assume (self:t) (c:Lit.t IArray.t) : unit =
let c = IArray.to_array_map (Sat_solver.make_atom sat) c in
Sat_solver.add_clause_a sat c Proof_default
let[@inline] assume_eq self t u expl : unit =
Congruence_closure.assert_eq (cc self) t u [expl]
(* TODO: remove? use a special constant + micro theory instead? *)
let[@inline] assume_distinct self l ~neq lit : unit =
Congruence_closure.assert_distinct (cc self) l lit ~neq
CC.assert_distinct (cc self) l lit ~neq
let check_model (_s:t) : unit =
Log.debug 1 "(smt.solver.check-model)";

View file

@ -47,7 +47,7 @@ val create :
val solver : t -> Sat_solver.t
val th_combine : t -> Theory_combine.t
val add_theory : t -> Theory.t -> unit
val cc : t -> Congruence_closure.t
val cc : t -> CC.t
val stats : t -> Stat.t
val tst : t -> Term.state
@ -56,7 +56,6 @@ val mk_atom_t : t -> ?sign:bool -> Term.t -> Atom.t
val assume : t -> Lit.t IArray.t -> unit
val assume_eq : t -> Term.t -> Term.t -> Lit.t -> unit
val assume_distinct : t -> Term.t list -> neq:Term.t -> Lit.t -> unit
val solve :

View file

@ -3,7 +3,6 @@ module Vec = Msat.Vec
module Log = Msat.Log
module Fmt = CCFormat
module Node_bits = CCBitField.Make(struct end)
(* for objects that are expanded on demand only *)
type 'a lazily_expanded =
@ -21,43 +20,9 @@ type term = {
and 'a term_view =
| Bool of bool
| App_cst of cst * 'a IArray.t (* full, first-order application *)
| Eq of 'a * 'a
| If of 'a * 'a * 'a
(** A node of the congruence closure.
An equivalence class is represented by its "root" element,
the representative.
If there is a normal form in the congruence class, then the
representative is a normal form *)
and equiv_class = {
n_term: term;
mutable n_bits: Node_bits.t; (* bitfield for various properties *)
mutable n_parents: equiv_class Bag.t; (* parent terms of this node *)
mutable n_root: equiv_class; (* representative of congruence class (itself if a representative) *)
mutable n_next: equiv_class; (* pointer to next element of congruence class *)
mutable n_size: int; (* size of the class *)
mutable n_expl: explanation_forest_link; (* the rooted forest for explanations *)
mutable n_payload: equiv_class_payload list; (* list of theory payloads *)
mutable n_tags: (equiv_class * explanation) Util.Int_map.t; (* "distinct" tags (i.e. set of `(distinct t1…tn)` terms this belongs to *)
}
(** Theory-extensible payloads *)
and equiv_class_payload = ..
and explanation_forest_link =
| E_none
| E_some of {
next: equiv_class;
expl: explanation;
}
(* atomic explanation in the congruence closure *)
and explanation =
| E_reduction (* by pure reduction, tautologically equal *)
| E_merges of (equiv_class * equiv_class) IArray.t (* caused by these merges *)
| E_lit of lit (* because of this literal *)
| E_lits of lit list (* because of this (true) conjunction *)
(* boolean literal *)
and lit = {
lit_term: term;
@ -157,23 +122,6 @@ let hash_lit a =
let sign = a.lit_sign in
Hash.combine3 2 (Hash.bool sign) (term_hash_ a.lit_term)
let cmp_cc_node a b = term_cmp_ a.n_term b.n_term
let cmp_exp a b =
let toint = function
| E_merges _ -> 0 | E_lit _ -> 1
| E_reduction -> 2 | E_lits _ -> 3
in
begin match a, b with
| E_merges l1, E_merges l2 ->
IArray.compare (CCOrd.pair cmp_cc_node cmp_cc_node) l1 l2
| E_reduction, E_reduction -> 0
| E_lit l1, E_lit l2 -> cmp_lit l1 l2
| E_lits l1, E_lits l2 -> CCList.compare cmp_lit l1 l2
| E_merges _, _ | E_lit _, _ | E_lits _, _ | E_reduction, _
-> CCInt.compare (toint a)(toint b)
end
let pp_cst out a = ID.pp out a.cst_id
let id_of_cst a = a.cst_id
@ -215,6 +163,7 @@ let pp_term_view_gen ~pp_id ~pp_t out = function
pp_id out (id_of_cst c)
| App_cst (f,l) ->
Fmt.fprintf out "(@[<1>%a@ %a@])" pp_id (id_of_cst f) (Util.pp_iarray pp_t) l
| Eq (a,b) -> Fmt.fprintf out "(@[<hv>=@ %a@ %a@])" pp_t a pp_t b
| If (a, b, c) ->
Fmt.fprintf out "(@[if %a@ %a@ %a@])" pp_t a pp_t b pp_t c
@ -233,14 +182,5 @@ let pp_lit out l =
if l.lit_sign then pp_term out l.lit_term
else Format.fprintf out "(@[@<1>¬@ %a@])" pp_term l.lit_term
let pp_cc_node out n = pp_term out n.n_term
let pp_explanation out (e:explanation) = match e with
| E_reduction -> Fmt.string out "reduction"
| E_lit lit -> pp_lit out lit
| E_lits l -> CCFormat.Dump.list pp_lit out l
| E_merges l ->
Format.fprintf out "(@[<hv1>merges@ %a@])"
Fmt.(seq ~sep:(return "@ ") @@ within "[" "]" @@ hvbox @@
pair ~sep:(return "@ <-> ") pp_cc_node pp_cc_node)
(IArray.to_seq l)
let pp_proof out = function
| Proof_default -> Fmt.fprintf out "<default proof>"

View file

@ -10,6 +10,7 @@ type t = term = {
type 'a view = 'a term_view =
| Bool of bool
| App_cst of cst * 'a IArray.t
| Eq of 'a * 'a
| If of 'a * 'a * 'a
let[@inline] id t = t.term_id
@ -47,6 +48,7 @@ let[@inline] make st (c:t term_view) : t =
let[@inline] true_ st = Lazy.force st.true_
let[@inline] false_ st = Lazy.force st.false_
let bool st b = if b then true_ st else false_ st
let create ?(size=1024) () : state =
let rec st ={
@ -66,9 +68,9 @@ let app_cst st f a =
let cell = Term_cell.app_cst f a in
make st cell
let const st c = app_cst st c IArray.empty
let if_ st a b c = make st (Term_cell.if_ a b c)
let[@inline] const st c = app_cst st c IArray.empty
let[@inline] if_ st a b c = make st (Term_cell.if_ a b c)
let[@inline] eq st a b = make st (Term_cell.eq a b)
(* "eager" and, evaluating [a] first *)
let and_eager st a b = if_ st a b (false_ st)
@ -87,10 +89,12 @@ let[@inline] is_const t = match view t with
| _ -> false
let cc_view (t:t) =
let module C = Mini_cc in
let module C = Sidekick_cc in
match view t with
| Bool b -> C.Bool b
| App_cst (f,args) -> C.App (f, IArray.to_seq args)
| App_cst (f,_) when not (Cst.do_cc f) -> C.Opaque t (* skip *)
| App_cst (f,args) -> C.App_fun (f, IArray.to_seq args)
| Eq (a,b) -> C.Eq (a, b)
| If (a,b,c) -> C.If (a,b,c)
module As_key = struct
@ -109,6 +113,7 @@ let to_seq t yield =
match view t with
| Bool _ -> ()
| App_cst (_,a) -> IArray.iter aux a
| Eq (a,b) -> aux a; aux b
| If (a,b,c) -> aux a; aux b; aux c
in
aux t
@ -121,3 +126,14 @@ let as_cst_undef (t:term): (cst * Ty.Fun.t) option =
let pp = Solver_types.pp_term
(* TODO
module T_arg = struct
module Fun = Cst
module Term = struct
include Term
let view = cc_view
end
end
module Mini_cc = Mini_cc.Make(T_arg)
*)

View file

@ -10,6 +10,7 @@ type t = term = {
type 'a view = 'a term_view =
| Bool of bool
| App_cst of cst * 'a IArray.t
| Eq of 'a * 'a
| If of 'a * 'a * 'a
val id : t -> int
@ -26,8 +27,10 @@ val create : ?size:int -> unit -> state
val make : state -> t view -> t
val true_ : state -> t
val false_ : state -> t
val bool : state -> bool -> t
val const : state -> cst -> t
val app_cst : state -> cst -> t IArray.t -> t
val eq : state -> t -> t -> t
val if_: state -> t -> t -> t -> t
val and_eager : state -> t -> t -> t (* evaluate left argument first *)
@ -49,7 +52,7 @@ val is_true : t -> bool
val is_false : t -> bool
val is_const : t -> bool
val cc_view : t -> (cst,t,t Sequence.t) Mini_cc.view
val cc_view : t -> (cst,t,t Sequence.t) Sidekick_cc.view
(* return [Some] iff the term is an undefined constant *)
val as_cst_undef : t -> (cst * Ty.Fun.t) option

View file

@ -6,6 +6,7 @@ open Solver_types
type 'a view = 'a Solver_types.term_view =
| Bool of bool
| App_cst of cst * 'a IArray.t
| Eq of 'a * 'a
| If of 'a * 'a * 'a
type t = term view
@ -25,6 +26,7 @@ module Make_eq(A : ARG) = struct
| Bool b -> Hash.bool b
| App_cst (f,l) ->
Hash.combine3 4 (Cst.hash f) (Hash.iarray sub_hash l)
| Eq (a,b) -> Hash.combine3 12 (sub_hash a) (sub_hash b)
| If (a,b,c) -> Hash.combine4 7 (sub_hash a) (sub_hash b) (sub_hash c)
(* equality that relies on physical equality of subterms *)
@ -32,9 +34,10 @@ module Make_eq(A : ARG) = struct
| Bool b1, Bool b2 -> CCBool.equal b1 b2
| App_cst (f1, a1), App_cst (f2, a2) ->
Cst.equal f1 f2 && IArray.equal sub_eq a1 a2
| Eq(a1,b1), Eq(a2,b2) -> sub_eq a1 a2 && sub_eq b1 b2
| If (a1,b1,c1), If (a2,b2,c2) ->
sub_eq a1 a2 && sub_eq b1 b2 && sub_eq c1 c2
| Bool _, _ | App_cst _, _ | If _, _
| Bool _, _ | App_cst _, _ | If _, _ | Eq _, _
-> false
let pp = Solver_types.pp_term_view_gen ~pp_id:ID.pp_name ~pp_t:A.pp
@ -53,17 +56,25 @@ let false_ = Bool false
let is_value = function
| Bool _ -> true
| App_cst ({cst_view=Cst_def r;_}, _) -> r.is_value
| If _ | App_cst _ -> false
| If _ | App_cst _ | Eq _ -> false
let app_cst f a = App_cst (f, a)
let const c = App_cst (c, IArray.empty)
let eq a b =
if term_equal_ a b then (
Bool true
) else (
(* canonize *)
let a,b = if a.term_id > b.term_id then b, a else a, b in
Eq (a,b)
)
let if_ a b c =
assert (Ty.equal b.term_ty c.term_ty);
If (a,b,c)
let ty (t:t): Ty.t = match t with
| Bool _ -> Ty.prop
| Bool _ | Eq _ -> Ty.prop
| App_cst (f, args) ->
begin match Cst.view f with
| Cst_undef fty ->

View file

@ -4,6 +4,7 @@ open Solver_types
type 'a view = 'a Solver_types.term_view =
| Bool of bool
| App_cst of cst * 'a IArray.t
| Eq of 'a * 'a
| If of 'a * 'a * 'a
type t = term view
@ -15,6 +16,7 @@ val true_ : t
val false_ : t
val const : cst -> t
val app_cst : cst -> term IArray.t -> t
val eq : term -> term -> t
val if_ : term -> term -> term -> t
val is_value : t -> bool

View file

@ -18,6 +18,9 @@ end
Its negation will become a conflict clause *)
type conflict = Lit.t list
module CC_eq_class = CC.N
module CC_expl = CC.Expl
(** Actions available to a theory during its lifetime *)
module type ACTIONS = sig
val raise_conflict: conflict -> 'a
@ -41,12 +44,15 @@ module type ACTIONS = sig
(** Add toplevel clause to the SAT solver. This clause will
not be backtracked. *)
val find: Term.t -> Eq_class.t
(** Find representative of this term *)
val cc_add_term: Term.t -> CC_eq_class.t
(** add/get term to the congruence closure *)
val all_classes: Eq_class.t Sequence.t
val cc_find: CC_eq_class.t -> CC_eq_class.t
(** Find representative of this in the congruence closure *)
val cc_all_classes: CC_eq_class.t Sequence.t
(** All current equivalence classes
(caution: linear in the number of terms existing in the solver) *)
(caution: linear in the number of terms existing in the congruence closure) *)
end
type actions = (module ACTIONS)
@ -60,7 +66,7 @@ module type S = sig
val create : Term.state -> t
(** Instantiate the theory's state *)
val on_merge: t -> actions -> Eq_class.t -> Eq_class.t -> Explanation.t -> unit
val on_merge: t -> actions -> CC_eq_class.t -> CC_eq_class.t -> CC_expl.t -> unit
(** Called when two classes are merged *)
val partial_check : t -> actions -> Lit.t Sequence.t -> unit
@ -70,7 +76,7 @@ module type S = sig
(** Final check, must be complete (i.e. must raise a conflict
if the set of literals is not satisfiable) *)
val mk_model : t -> Lit.t Sequence.t -> Model.t
val mk_model : t -> Lit.t Sequence.t -> Model.t -> Model.t
(** Make a model for this theory's terms *)
val push_level : t -> unit
@ -91,7 +97,7 @@ let make
?(check_invariants=fun _ -> ())
?(on_merge=fun _ _ _ _ _ -> ())
?(partial_check=fun _ _ _ -> ())
?(mk_model=fun _ _ -> Model.empty)
?(mk_model=fun _ _ m -> m)
?(push_level=fun _ -> ())
?(pop_levels=fun _ _ -> ())
~name

View file

@ -3,7 +3,6 @@
(** Combine the congruence closure with a number of plugins *)
module C_clos = Congruence_closure
open Solver_types
module Proof = struct
@ -12,6 +11,8 @@ module Proof = struct
end
module Formula = Lit
module Eq_class = CC.N
module Expl = CC.Expl
type formula = Lit.t
type proof = Proof.t
@ -24,11 +25,11 @@ type theory_state =
type t = {
tst: Term.state;
(** state for managing terms *)
cc: C_clos.t lazy_t;
cc: CC.t lazy_t;
(** congruence closure *)
mutable theories : theory_state list;
(** Set of theories *)
new_merges: (Eq_class.t * Eq_class.t * explanation) Vec.t;
new_merges: (Eq_class.t * Eq_class.t * Expl.t) Vec.t;
}
let[@inline] cc (t:t) = Lazy.force t.cc
@ -41,24 +42,28 @@ let[@inline] theories (self:t) : theory_state Sequence.t =
(* handle a literal assumed by the SAT solver *)
let assert_lits_ ~final (self:t) acts (lits:Lit.t Sequence.t) : unit =
Msat.Log.debugf 2
(fun k->k "(@[<1>@{<green>th_combine.assume_lits@}@ @[%a@]@])" (Fmt.seq Lit.pp) lits);
(fun k->k "(@[<hv1>@{<green>th_combine.assume_lits@}@ %a@])"
(Util.pp_seq ~sep:";" Lit.pp) lits);
(* transmit to CC *)
Vec.clear self.new_merges;
let cc = cc self in
C_clos.assert_lits cc lits;
if not final then (
CC.assert_lits cc lits;
);
(* transmit to theories. *)
C_clos.check cc acts;
CC.check cc acts;
let module A = struct
let[@inline] raise_conflict c : 'a = acts.Msat.acts_raise_conflict c Proof_default
let[@inline] propagate_eq t u expl : unit = C_clos.assert_eq cc t u expl
let propagate_distinct ts ~neq expl = C_clos.assert_distinct cc ts ~neq expl
let[@inline] propagate_eq t u expl : unit = CC.assert_eq cc t u expl
let propagate_distinct ts ~neq expl = CC.assert_distinct cc ts ~neq expl
let[@inline] propagate p cs : unit = acts.Msat.acts_propagate p (Msat.Consequence (cs, Proof_default))
let[@inline] add_local_axiom lits : unit =
acts.Msat.acts_add_clause ~keep:false lits Proof_default
let[@inline] add_persistent_axiom lits : unit =
acts.Msat.acts_add_clause ~keep:true lits Proof_default
let[@inline] find t = C_clos.find_t cc t
let all_classes = C_clos.all_classes cc
let[@inline] cc_add_term t = CC.add_term cc t
let[@inline] cc_find t = CC.find cc t
let cc_all_classes = CC.all_classes cc
end in
let acts = (module A : Theory.ACTIONS) in
theories self
@ -83,10 +88,10 @@ let check_ ~final (self:t) (acts:_ Msat.acts) =
assert_lits_ ~final self acts iter
let add_formula (self:t) (lit:Lit.t) =
let t = Lit.view lit in
let t = Lit.term lit in
let lazy cc = self.cc in
let n = C_clos.add cc t in
Eq_class.set_field Eq_class.field_is_literal true n;
let n = CC.add_term cc t in
CC.set_as_lit cc n (Lit.abs lit);
()
(* propagation from the bool solver *)
@ -98,21 +103,21 @@ let[@inline] final_check (self:t) (acts:_ Msat.acts) : unit =
check_ ~final:true self acts
let push_level (self:t) : unit =
C_clos.push_level (cc self);
CC.push_level (cc self);
theories self (fun (Th_state ((module Th), st)) -> Th.push_level st)
let pop_levels (self:t) n : unit =
C_clos.pop_levels (cc self) n;
CC.pop_levels (cc self) n;
theories self (fun (Th_state ((module Th), st)) -> Th.pop_levels st n)
let mk_model (self:t) lits : Model.t =
let m =
Sequence.fold
(fun m (Th_state ((module Th),st)) -> Model.merge m (Th.mk_model st lits))
(fun m (Th_state ((module Th),st)) -> Th.mk_model st lits m)
Model.empty (theories self)
in
(* now complete model using CC *)
Congruence_closure.mk_model (cc self) m
CC.mk_model (cc self) m
(** {2 Interface to Congruence Closure} *)
@ -131,16 +136,16 @@ let create () : t =
cc = lazy (
(* lazily tie the knot *)
let on_merge = on_merge_from_cc self in
C_clos.create ~on_merge ~size:`Big self.tst;
CC.create ~on_merge ~size:`Big self.tst;
);
theories = [];
} in
ignore (Lazy.force @@ self.cc : C_clos.t);
ignore (Lazy.force @@ self.cc : CC.t);
self
let check_invariants (self:t) =
if Util._CHECK_INVARIANTS then (
Congruence_closure.check_invariants (cc self);
CC.check_invariants (cc self);
)
let add_theory (self:t) (th:Theory.t) : unit =

View file

@ -13,7 +13,7 @@ include Msat.Solver_intf.PLUGIN_CDCL_T
val create : unit -> t
val cc : t -> Congruence_closure.t
val cc : t -> CC.t
val tst : t -> Term.state
type theory_state =

View file

@ -19,3 +19,5 @@ let equal = eq_value
let hash = hash_value
let pp = pp_value
let fresh (t:term) : t =
mk_elt (ID.makef "v_%d" t.term_id) t.term_ty

View file

@ -15,6 +15,8 @@ val is_bool : t -> bool
val is_true : t -> bool
val is_false : t -> bool
val fresh : Term.t -> t
include Intf.EQ with type t := t
include Intf.HASH with type t := t
include Intf.PRINT with type t := t

View file

@ -2,7 +2,8 @@
(library
(name Sidekick_smt)
(public_name sidekick.smt)
(libraries containers containers.data sequence sidekick.util msat zarith)
(libraries containers containers.data sequence
sidekick.util sidekick.cc msat zarith)
(flags :standard -warn-error -a+8
-color always -safe-string -short-paths -open Sidekick_util)
(ocamlopt_flags :standard -O3 -color always

View file

@ -15,7 +15,6 @@ let id_not = ID.make "not"
let id_and = ID.make "and"
let id_or = ID.make "or"
let id_imply = ID.make "=>"
let id_eq = ID.make "="
let id_distinct = ID.make "distinct"
type 'a view =
@ -32,8 +31,6 @@ exception Not_a_th_term
let view_id cst_id args =
if ID.equal cst_id id_not && IArray.length args=1 then (
B_not (IArray.get args 0)
) else if ID.equal cst_id id_eq && IArray.length args=2 then (
B_eq (IArray.get args 0, IArray.get args 1)
) else if ID.equal cst_id id_and then (
B_and args
) else if ID.equal cst_id id_or then (
@ -45,13 +42,14 @@ let view_id cst_id args =
) else if ID.equal cst_id id_distinct then (
B_distinct args
) else (
raise Not_a_th_term
raise_notrace Not_a_th_term
)
let view (t:Term.t) : term view =
match Term.view t with
| Eq (a,b) -> B_eq (a,b)
| App_cst ({cst_id; _}, args) ->
(try view_id cst_id args with Not_a_th_term -> B_atom t)
begin try view_id cst_id args with Not_a_th_term -> B_atom t end
| _ -> B_atom t
@ -59,9 +57,6 @@ module C = struct
let get_ty _ _ = Ty.prop
(* no congruence closure, except for `=` *)
let relevant id _ _ = ID.equal id_eq id
let abs ~self _a =
match Term.view self with
| App_cst ({cst_id;_}, args) when ID.equal cst_id id_not && IArray.length args=1 ->
@ -89,6 +84,9 @@ module C = struct
| B_not _ | B_and _ | B_or _ | B_imply _
-> Error.errorf "non boolean value in boolean connective"
(* no congruence closure for boolean terms *)
let relevant _id _ _ = false
let mk_cst ?(do_cc=false) id : Cst.t =
{cst_id=id;
cst_view=Cst_def {
@ -98,7 +96,6 @@ module C = struct
let and_ = mk_cst id_and
let or_ = mk_cst id_or
let imply = mk_cst id_imply
let eq = mk_cst ~do_cc:true id_eq
let distinct = mk_cst id_distinct
end
@ -134,13 +131,7 @@ let or_l st l =
let and_ st a b = and_l st [a;b]
let or_ st a b = or_l st [a;b]
let eq st a b =
if Term.equal a b then (
Term.true_ st
) else (
let a,b = if Term.id a > Term.id b then b, a else a, b in
Term.app_cst st C.eq (IArray.doubleton a b)
)
let eq = Term.eq
let not_ st a =
match as_id id_not a, Term.view a with
@ -164,7 +155,7 @@ let distinct st = function
module Lit = struct
include Lit
let eq tst a b = Lit.atom ~sign:true (eq tst a b)
let neq tst a b = Lit.atom ~sign:false (neq tst a b)
let neq tst a b = neg @@ eq tst a b
end
type t = {
@ -175,14 +166,8 @@ let tseitin (_self:t) (acts:Theory.actions) (lit:Lit.t) (lit_t:term) (v:term vie
let (module A) = acts in
Log.debugf 5 (fun k->k "(@[th_bool.tseitin@ %a@])" Lit.pp lit);
match v with
| B_atom _ -> ()
| B_not _ -> assert false (* normalized *)
| B_eq (t,u) ->
if Lit.sign lit then (
A.propagate_eq t u [lit]
) else (
A.propagate_distinct [t;u] ~neq:lit_t lit
)
| B_atom _ | B_eq _ -> () (* CC will manage *)
| B_distinct l ->
let l = IArray.to_list l in
if Lit.sign lit then (
@ -197,7 +182,7 @@ let tseitin (_self:t) (acts:Theory.actions) (lit:Lit.t) (lit_t:term) (v:term vie
IArray.iter
(fun sub ->
let sublit = Lit.atom sub in
A.propagate sublit [lit])
A.add_local_axiom [Lit.neg lit; sublit])
subs
) else (
(* propagate [¬lit => _i ¬ subs_i] *)
@ -216,7 +201,7 @@ let tseitin (_self:t) (acts:Theory.actions) (lit:Lit.t) (lit_t:term) (v:term vie
IArray.iter
(fun sub ->
let sublit = Lit.atom ~sign:false sub in
A.propagate sublit [lit])
A.add_local_axiom [Lit.neg lit; sublit])
subs
)
| B_imply (guard,concl) ->
@ -239,7 +224,7 @@ let tseitin (_self:t) (acts:Theory.actions) (lit:Lit.t) (lit_t:term) (v:term vie
let partial_check (self:t) acts (lits:Lit.t Sequence.t) =
lits
(fun lit ->
let t = Lit.view lit in
let t = Lit.term lit in
match view t with
| B_atom _ -> ()
| v -> tseitin self acts lit t v)

View file

@ -9,3 +9,4 @@ module Backtrack_stack = Backtrack_stack
module Error = Error
module IArray = IArray
module Intf = Intf
module Bag = Bag