mirror of
https://github.com/c-cube/sidekick.git
synced 2026-01-22 01:06:43 -05:00
feat(cc): split sub-library sidekick.cc, make it fully functorized
This commit is contained in:
parent
de1653bdcc
commit
a463dbb4b5
39 changed files with 1558 additions and 1237 deletions
112
src/cc/CC_types.ml
Normal file
112
src/cc/CC_types.ml
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
|
||||
(** {1 Types used by the congruence closure} *)
|
||||
|
||||
type ('f, 't, 'ts) view =
|
||||
| Bool of bool
|
||||
| App_fun of 'f * 'ts
|
||||
| App_ho of 't * 'ts
|
||||
| If of 't * 't * 't
|
||||
| Eq of 't * 't
|
||||
| Opaque of 't (* do not enter *)
|
||||
|
||||
let[@inline] map_view ~f_f ~f_t ~f_ts (v:_ view) : _ view =
|
||||
match v with
|
||||
| Bool b -> Bool b
|
||||
| App_fun (f, args) -> App_fun (f_f f, f_ts args)
|
||||
| App_ho (f, args) -> App_ho (f_t f, f_ts args)
|
||||
| If (a,b,c) -> If (f_t a, f_t b, f_t c)
|
||||
| Eq (a,b) -> Eq (f_t a, f_t b)
|
||||
| Opaque t -> Opaque (f_t t)
|
||||
|
||||
let iter_view ~f_f ~f_t ~f_ts (v:_ view) : unit =
|
||||
match v with
|
||||
| Bool _ -> ()
|
||||
| App_fun (f, args) -> f_f f; f_ts args
|
||||
| App_ho (f, args) -> f_t f; f_ts args
|
||||
| If (a,b,c) -> f_t a; f_t b; f_t c;
|
||||
| Eq (a,b) -> f_t a; f_t b
|
||||
| Opaque t -> f_t t
|
||||
|
||||
module type TERM = sig
|
||||
module Fun : sig
|
||||
type t
|
||||
val equal : t -> t -> bool
|
||||
val hash : t -> int
|
||||
val pp : t Fmt.printer
|
||||
end
|
||||
|
||||
module Term : sig
|
||||
type t
|
||||
val equal : t -> t -> bool
|
||||
val hash : t -> int
|
||||
val pp : t Fmt.printer
|
||||
|
||||
type state
|
||||
|
||||
val bool : state -> bool -> t
|
||||
|
||||
(** View the term through the lens of the congruence closure *)
|
||||
val cc_view : t -> (Fun.t, t, t Sequence.t) view
|
||||
end
|
||||
end
|
||||
|
||||
module type TERM_LIT = sig
|
||||
include TERM
|
||||
|
||||
module Lit : sig
|
||||
type t
|
||||
val neg : t -> t
|
||||
val equal : t -> t -> bool
|
||||
val hash : t -> int
|
||||
val pp : t Fmt.printer
|
||||
|
||||
val sign : t -> bool
|
||||
val term : t -> Term.t
|
||||
end
|
||||
end
|
||||
|
||||
module type FULL = sig
|
||||
include TERM_LIT
|
||||
|
||||
module Proof : sig
|
||||
type t
|
||||
val pp : t Fmt.printer
|
||||
|
||||
val default : t
|
||||
(* TODO: to give more details
|
||||
val cc_lemma : unit -> t
|
||||
*)
|
||||
end
|
||||
|
||||
module Ty : sig
|
||||
type t
|
||||
|
||||
val equal : t -> t -> bool
|
||||
val hash : t -> int
|
||||
val pp : t Fmt.printer
|
||||
end
|
||||
|
||||
module Value : sig
|
||||
type t
|
||||
|
||||
val pp : t Fmt.printer
|
||||
|
||||
val fresh : Term.t -> t
|
||||
|
||||
val true_ : t
|
||||
val false_ : t
|
||||
end
|
||||
|
||||
module Model : sig
|
||||
type t
|
||||
|
||||
val pp : t Fmt.printer
|
||||
|
||||
val eval : t -> Term.t -> Value.t option
|
||||
(** Evaluate the term in the current model *)
|
||||
|
||||
val add : Term.t -> Value.t -> t -> t
|
||||
end
|
||||
end
|
||||
|
||||
(* TODO: micro theory *)
|
||||
939
src/cc/Congruence_closure.ml
Normal file
939
src/cc/Congruence_closure.ml
Normal file
|
|
@ -0,0 +1,939 @@
|
|||
|
||||
open CC_types
|
||||
|
||||
module type ARG = Congruence_closure_intf.ARG
|
||||
module type S = Congruence_closure_intf.S
|
||||
|
||||
module Bits = CCBitField.Make()
|
||||
|
||||
let field_is_pending = Bits.mk_field()
|
||||
(** true iff the node is in the [cc.pending] queue *)
|
||||
|
||||
let () = Bits.freeze()
|
||||
|
||||
type payload = Congruence_closure_intf.payload = ..
|
||||
|
||||
module Make(A: ARG) = struct
|
||||
type term = A.Term.t
|
||||
type term_state = A.Term.state
|
||||
type lit = A.Lit.t
|
||||
type fun_ = A.Fun.t
|
||||
type proof = A.Proof.t
|
||||
type value = A.Value.t
|
||||
type model = A.Model.t
|
||||
|
||||
(** Actions available to the theory *)
|
||||
type sat_actions = (Msat.void, lit, Msat.void, proof) Msat.acts
|
||||
|
||||
module T = A.Term
|
||||
module Fun = A.Fun
|
||||
|
||||
(** A node of the congruence closure.
|
||||
An equivalence class is represented by its "root" element,
|
||||
the representative. *)
|
||||
type node = {
|
||||
n_term: term;
|
||||
mutable n_sig0: signature option; (* initial signature *)
|
||||
mutable n_bits: Bits.t; (* bitfield for various properties *)
|
||||
mutable n_parents: node Bag.t; (* parent terms of this node *)
|
||||
mutable n_root: node; (* representative of congruence class (itself if a representative) *)
|
||||
mutable n_next: node; (* pointer to next element of congruence class *)
|
||||
mutable n_size: int; (* size of the class *)
|
||||
mutable n_as_lit: lit option;
|
||||
mutable n_expl: explanation_forest_link; (* the rooted forest for explanations *)
|
||||
mutable n_payload: payload list; (* list of theory payloads *)
|
||||
(* TODO: make a micro theory and move this inside *)
|
||||
mutable n_tags: (node * explanation) Util.Int_map.t;
|
||||
(* "distinct" tags (i.e. set of `(distinct t1…tn)` terms this belongs to *)
|
||||
}
|
||||
|
||||
and signature = (fun_, node, node list) view
|
||||
|
||||
and explanation_forest_link =
|
||||
| FL_none
|
||||
| FL_some of {
|
||||
next: node;
|
||||
expl: explanation;
|
||||
}
|
||||
|
||||
(* atomic explanation in the congruence closure *)
|
||||
and explanation =
|
||||
| E_reduction (* by pure reduction, tautologically equal *)
|
||||
| E_merges of (node * node) list (* caused by these merges *)
|
||||
| E_lit of lit (* because of this literal *)
|
||||
| E_lits of lit list (* because of this (true) conjunction *)
|
||||
(* TODO: congruence case (cheaper than "merges") *)
|
||||
|
||||
type repr = node
|
||||
type conflict = lit list
|
||||
|
||||
module N = struct
|
||||
type t = node
|
||||
|
||||
let[@inline] equal (n1:t) n2 = T.equal n1.n_term n2.n_term
|
||||
let[@inline] hash n = T.hash n.n_term
|
||||
let[@inline] term n = n.n_term
|
||||
let[@inline] payload n = n.n_payload
|
||||
let[@inline] pp out n = T.pp out n.n_term
|
||||
let[@inline] as_lit n = n.n_as_lit
|
||||
|
||||
let make (t:term) : t =
|
||||
let rec n = {
|
||||
n_term=t;
|
||||
n_sig0= None;
|
||||
n_bits=Bits.empty;
|
||||
n_parents=Bag.empty;
|
||||
n_as_lit=None; (* TODO: provide a method to do it *)
|
||||
n_root=n;
|
||||
n_expl=FL_none;
|
||||
n_payload=[];
|
||||
n_next=n;
|
||||
n_size=1;
|
||||
n_tags=Util.Int_map.empty;
|
||||
} in
|
||||
n
|
||||
|
||||
type nonrec payload = payload = ..
|
||||
|
||||
let set_payload ?(can_erase=fun _->false) n e =
|
||||
let rec aux = function
|
||||
| [] -> [e]
|
||||
| e' :: tail when can_erase e' -> e :: tail
|
||||
| e' :: tail -> e' :: aux tail
|
||||
in
|
||||
n.n_payload <- aux n.n_payload
|
||||
|
||||
let payload_find ~f:p n =
|
||||
let[@unroll 2] rec aux = function
|
||||
| [] -> None
|
||||
| e1 :: tail ->
|
||||
match p e1 with
|
||||
| Some _ as res -> res
|
||||
| None -> aux tail
|
||||
in
|
||||
aux n.n_payload
|
||||
|
||||
let payload_pred ~f:p n =
|
||||
begin match n.n_payload with
|
||||
| [] -> false
|
||||
| e :: _ when p e -> true
|
||||
| _ :: e :: _ when p e -> true
|
||||
| _ :: _ :: e :: _ when p e -> true
|
||||
| l -> List.exists p l
|
||||
end
|
||||
|
||||
let[@inline] get_field f t = Bits.get f t.n_bits
|
||||
let[@inline] set_field f b t = t.n_bits <- Bits.set f b t.n_bits
|
||||
end
|
||||
|
||||
module N_tbl = CCHashtbl.Make(N)
|
||||
|
||||
module Expl = struct
|
||||
type t = explanation
|
||||
|
||||
let equal (a:t) b =
|
||||
match a, b with
|
||||
| E_merges l1, E_merges l2 ->
|
||||
CCList.equal (CCPair.equal N.equal N.equal) l1 l2
|
||||
| E_reduction, E_reduction -> true
|
||||
| E_lit l1, E_lit l2 -> A.Lit.equal l1 l2
|
||||
| E_lits l1, E_lits l2 -> CCList.equal A.Lit.equal l1 l2
|
||||
| E_merges _, _ | E_lit _, _ | E_lits _, _ | E_reduction, _
|
||||
-> false
|
||||
|
||||
let hash (a:t) : int =
|
||||
let module H = CCHash in
|
||||
match a with
|
||||
| E_lit lit -> H.combine2 10 (A.Lit.hash lit)
|
||||
| E_lits l ->
|
||||
H.combine2 20 (H.list A.Lit.hash l)
|
||||
| E_merges l ->
|
||||
H.combine2 30 (H.list (H.pair N.hash N.hash) l)
|
||||
| E_reduction -> H.int 40
|
||||
|
||||
let pp out (e:explanation) = match e with
|
||||
| E_reduction -> Fmt.string out "reduction"
|
||||
| E_lit lit -> A.Lit.pp out lit
|
||||
| E_lits l -> CCFormat.Dump.list A.Lit.pp out l
|
||||
| E_merges l ->
|
||||
Format.fprintf out "(@[<hv1>merges@ %a@])"
|
||||
Fmt.(seq ~sep:(return "@ ") @@ within "[" "]" @@ hvbox @@
|
||||
pair ~sep:(return " ~@ ") N.pp N.pp)
|
||||
(Sequence.of_list l)
|
||||
|
||||
let[@inline] mk_merges l : t = E_merges l
|
||||
let[@inline] mk_lit l : t = E_lit l
|
||||
let[@inline] mk_lits = function [x] -> mk_lit x | l -> E_lits l
|
||||
let mk_reduction : t = E_reduction
|
||||
end
|
||||
|
||||
(** A signature is a shallow term shape where immediate subterms
|
||||
are representative *)
|
||||
module Signature = struct
|
||||
type t = signature
|
||||
let equal (s1:t) s2 : bool =
|
||||
match s1, s2 with
|
||||
| Bool b1, Bool b2 -> b1=b2
|
||||
| App_fun (f1,[]), App_fun (f2,[]) -> Fun.equal f1 f2
|
||||
| App_fun (f1,l1), App_fun (f2,l2) ->
|
||||
Fun.equal f1 f2 && CCList.equal N.equal l1 l2
|
||||
| App_ho (f1,l1), App_ho (f2,l2) ->
|
||||
N.equal f1 f2 && CCList.equal N.equal l1 l2
|
||||
| If (a1,b1,c1), If (a2,b2,c2) ->
|
||||
N.equal a1 a2 && N.equal b1 b2 && N.equal c1 c2
|
||||
| Eq (a1,b1), Eq (a2,b2) ->
|
||||
N.equal a1 a2 && N.equal b1 b2
|
||||
| Opaque u1, Opaque u2 -> N.equal u1 u2
|
||||
| Bool _, _ | App_fun _, _ | App_ho _, _ | If _, _
|
||||
| Eq _, _ | Opaque _, _
|
||||
-> false
|
||||
|
||||
let hash (s:t) : int =
|
||||
let module H = CCHash in
|
||||
match s with
|
||||
| Bool b -> H.combine2 10 (H.bool b)
|
||||
| App_fun (f, l) -> H.combine3 20 (Fun.hash f) (H.list N.hash l)
|
||||
| App_ho (f, l) -> H.combine3 30 (N.hash f) (H.list N.hash l)
|
||||
| Eq (a,b) -> H.combine3 40 (N.hash a) (N.hash b)
|
||||
| Opaque u -> H.combine2 50 (N.hash u)
|
||||
| If (a,b,c) -> H.combine4 60 (N.hash a)(N.hash b)(N.hash c)
|
||||
|
||||
let pp out = function
|
||||
| Bool b -> Fmt.bool out b
|
||||
| App_fun (f, []) -> Fun.pp out f
|
||||
| App_fun (f, l) -> Fmt.fprintf out "(@[%a@ %a@])" Fun.pp f (Util.pp_list N.pp) l
|
||||
| App_ho (f, []) -> N.pp out f
|
||||
| App_ho (f, l) -> Fmt.fprintf out "(@[%a@ %a@])" N.pp f (Util.pp_list N.pp) l
|
||||
| Opaque t -> N.pp out t
|
||||
| Eq (a,b) -> Fmt.fprintf out "(@[=@ %a@ %a@])" N.pp a N.pp b
|
||||
| If (a,b,c) -> Fmt.fprintf out "(@[ite@ %a@ %a@ %a@])" N.pp a N.pp b N.pp c
|
||||
end
|
||||
|
||||
module Sig_tbl = CCHashtbl.Make(Signature)
|
||||
module T_tbl = CCHashtbl.Make(T)
|
||||
|
||||
type combine_task =
|
||||
| CT_merge of node * node * explanation
|
||||
| CT_distinct of node list * int * explanation
|
||||
|
||||
type t = {
|
||||
tst: term_state;
|
||||
tbl: node T_tbl.t;
|
||||
(* internalization [term -> node] *)
|
||||
signatures_tbl : node Sig_tbl.t;
|
||||
(* map a signature to the corresponding node in some equivalence class.
|
||||
A signature is a [term_cell] in which every immediate subterm
|
||||
that participates in the congruence/evaluation relation
|
||||
is normalized (i.e. is its own representative).
|
||||
The critical property is that all members of an equivalence class
|
||||
that have the same "shape" (including head symbol)
|
||||
have the same signature *)
|
||||
pending: node Vec.t;
|
||||
combine: combine_task Vec.t;
|
||||
undo: (unit -> unit) Backtrack_stack.t;
|
||||
on_merge: (repr -> repr -> explanation -> unit) option;
|
||||
mutable ps_lits: lit list; (* TODO: thread it around instead? *)
|
||||
(* proof state *)
|
||||
ps_queue: (node*node) Vec.t;
|
||||
(* pairs to explain *)
|
||||
true_ : node lazy_t;
|
||||
false_ : node lazy_t;
|
||||
}
|
||||
(* TODO: an additional union-find to keep track, for each term,
|
||||
of the terms they are known to be equal to, according
|
||||
to the current explanation. That allows not to prove some equality
|
||||
several times.
|
||||
See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *)
|
||||
|
||||
let[@inline] is_root_ (n:node) : bool = n.n_root == n
|
||||
let[@inline] size_ (r:repr) = r.n_size
|
||||
let[@inline] true_ cc = Lazy.force cc.true_
|
||||
let[@inline] false_ cc = Lazy.force cc.false_
|
||||
|
||||
let[@inline] on_backtrack cc f : unit =
|
||||
Backtrack_stack.push_if_nonzero_level cc.undo f
|
||||
|
||||
(* check if [t] is in the congruence closure.
|
||||
Invariant: [in_cc t ∧ do_cc t => forall u subterm t, in_cc u] *)
|
||||
let[@inline] mem (cc:t) (t:term): bool = T_tbl.mem cc.tbl t
|
||||
|
||||
(* find representative, recursively *)
|
||||
let[@unroll 1] rec find_rec (n:node) : repr =
|
||||
if n==n.n_root then (
|
||||
n
|
||||
) else (
|
||||
(* TODO: path compression, assuming backtracking restores equiv classes
|
||||
properly *)
|
||||
let root = find_rec n.n_root in
|
||||
root
|
||||
)
|
||||
|
||||
(* traverse the equivalence class of [n] *)
|
||||
let iter_class_ (n:node) : node Sequence.t =
|
||||
fun yield ->
|
||||
let rec aux u =
|
||||
yield u;
|
||||
if u.n_next != n then aux u.n_next
|
||||
in
|
||||
aux n
|
||||
|
||||
(* non-recursive, inlinable function for [find] *)
|
||||
let[@inline] find_ (n:node) : repr =
|
||||
if n == n.n_root then n else find_rec n.n_root
|
||||
|
||||
let[@inline] same_class (n1:node)(n2:node): bool =
|
||||
N.equal (find_ n1) (find_ n2)
|
||||
|
||||
let[@inline] find _ n = find_ n
|
||||
|
||||
(* print full state *)
|
||||
let pp_full out (cc:t) : unit =
|
||||
let pp_next out n =
|
||||
Fmt.fprintf out "@ :next %a" N.pp n.n_next in
|
||||
let pp_root out n =
|
||||
if is_root_ n then Fmt.string out " :is-root" else Fmt.fprintf out "@ :root %a" N.pp n.n_root in
|
||||
let pp_expl out n = match n.n_expl with
|
||||
| FL_none -> ()
|
||||
| FL_some e ->
|
||||
Fmt.fprintf out " (@[:forest %a :expl %a@])" N.pp e.next Expl.pp e.expl
|
||||
in
|
||||
let pp_n out n =
|
||||
Fmt.fprintf out "(@[%a%a%a%a@])" T.pp n.n_term pp_root n pp_next n pp_expl n
|
||||
and pp_sig_e out (s,n) =
|
||||
Fmt.fprintf out "(@[<1>%a@ ~~> %a%a@])" Signature.pp s N.pp n pp_root n
|
||||
in
|
||||
Fmt.fprintf out
|
||||
"(@[@{<yellow>cc.state@}@ (@[<hv>:nodes@ %a@])@ (@[<hv>:sig-tbl@ %a@])@])"
|
||||
(Util.pp_seq ~sep:" " pp_n) (T_tbl.values cc.tbl)
|
||||
(Util.pp_seq ~sep:" " pp_sig_e) (Sig_tbl.to_seq cc.signatures_tbl)
|
||||
|
||||
(* compute up-to-date signature *)
|
||||
let update_sig (s:signature) : Signature.t =
|
||||
CC_types.map_view s
|
||||
~f_f:(fun x->x)
|
||||
~f_t:find_
|
||||
~f_ts:(List.map find_)
|
||||
|
||||
(* find whether the given (parent) term corresponds to some signature
|
||||
in [signatures_] *)
|
||||
let[@inline] find_signature cc (s:signature) : repr option =
|
||||
Sig_tbl.get cc.signatures_tbl s
|
||||
|
||||
let add_signature cc (s:signature) (n:node) : unit =
|
||||
(* add, but only if not present already *)
|
||||
match Sig_tbl.find cc.signatures_tbl s with
|
||||
| exception Not_found ->
|
||||
Log.debugf 15
|
||||
(fun k->k "(@[cc.add-sig@ %a@ ~~> %a@])" Signature.pp s N.pp n);
|
||||
on_backtrack cc (fun () -> Sig_tbl.remove cc.signatures_tbl s);
|
||||
Sig_tbl.add cc.signatures_tbl s n;
|
||||
| r' ->
|
||||
assert (same_class n r');
|
||||
()
|
||||
|
||||
let push_pending cc t : unit =
|
||||
if not @@ N.get_field field_is_pending t then (
|
||||
Log.debugf 5 (fun k->k "(@[<hv1>cc.push_pending@ %a@])" N.pp t);
|
||||
N.set_field field_is_pending true t;
|
||||
Vec.push cc.pending t
|
||||
)
|
||||
|
||||
let push_combine cc t u e : unit =
|
||||
Log.debugf 5
|
||||
(fun k->k "(@[<hv1>cc.push_combine@ %a ~@ %a@ :expl %a@])"
|
||||
N.pp t N.pp u Expl.pp e);
|
||||
Vec.push cc.combine @@ CT_merge (t,u,e)
|
||||
|
||||
(* re-root the explanation tree of the equivalence class of [n]
|
||||
so that it points to [n].
|
||||
postcondition: [n.n_expl = None] *)
|
||||
let rec reroot_expl (cc:t) (n:node): unit =
|
||||
let old_expl = n.n_expl in
|
||||
begin match old_expl with
|
||||
| FL_none -> () (* already root *)
|
||||
| FL_some {next=u; expl=e_n_u} ->
|
||||
reroot_expl cc u;
|
||||
u.n_expl <- FL_some {next=n; expl=e_n_u};
|
||||
n.n_expl <- FL_none;
|
||||
end
|
||||
|
||||
let raise_conflict (cc:t) (acts:sat_actions) (e:conflict): _ =
|
||||
(* clear tasks queue *)
|
||||
Vec.iter (N.set_field field_is_pending false) cc.pending;
|
||||
Vec.clear cc.pending;
|
||||
Vec.clear cc.combine;
|
||||
let c = List.rev_map A.Lit.neg e in
|
||||
acts.Msat.acts_raise_conflict c A.Proof.default
|
||||
|
||||
let[@inline] all_classes cc : repr Sequence.t =
|
||||
T_tbl.values cc.tbl
|
||||
|> Sequence.filter is_root_
|
||||
|
||||
(* TODO: use markers and lockstep iteration instead *)
|
||||
(* distance from [t] to its root in the proof forest *)
|
||||
let[@inline][@unroll 2] rec distance_to_root (n:node): int = match n.n_expl with
|
||||
| FL_none -> 0
|
||||
| FL_some {next=t'; _} -> 1 + distance_to_root t'
|
||||
|
||||
(* TODO: bool flag on nodes + stepwise progress + cleanup *)
|
||||
(* find the closest common ancestor of [a] and [b] in the proof forest *)
|
||||
let find_common_ancestor (a:node) (b:node) : node =
|
||||
let d_a = distance_to_root a in
|
||||
let d_b = distance_to_root b in
|
||||
(* drop [n] nodes in the path from [t] to its root *)
|
||||
let rec drop_ n t =
|
||||
if n=0 then t
|
||||
else match t.n_expl with
|
||||
| FL_none -> assert false
|
||||
| FL_some {next=t'; _} -> drop_ (n-1) t'
|
||||
in
|
||||
(* reduce to the problem where [a] and [b] have the same distance to root *)
|
||||
let a, b =
|
||||
if d_a > d_b then drop_ (d_a-d_b) a, b
|
||||
else if d_a < d_b then a, drop_ (d_b-d_a) b
|
||||
else a, b
|
||||
in
|
||||
(* traverse stepwise until a==b *)
|
||||
let rec aux_same_dist a b =
|
||||
if a==b then a
|
||||
else match a.n_expl, b.n_expl with
|
||||
| FL_none, _ | _, FL_none -> assert false
|
||||
| FL_some {next=a'; _}, FL_some {next=b'; _} -> aux_same_dist a' b'
|
||||
in
|
||||
aux_same_dist a b
|
||||
|
||||
let[@inline] ps_add_obligation (cc:t) a b = Vec.push cc.ps_queue (a,b)
|
||||
let[@inline] ps_add_lit ps l = ps.ps_lits <- l :: ps.ps_lits
|
||||
|
||||
(* TODO: remove *)
|
||||
let ps_clear (cc:t) =
|
||||
cc.ps_lits <- [];
|
||||
Vec.clear cc.ps_queue;
|
||||
()
|
||||
|
||||
let decompose_explain cc (e:explanation): unit =
|
||||
Log.debugf 5 (fun k->k "(@[cc.decompose_expl@ %a@])" Expl.pp e);
|
||||
begin match e with
|
||||
| E_reduction -> ()
|
||||
| E_lit lit -> ps_add_lit cc lit
|
||||
| E_lits l -> List.iter (ps_add_lit cc) l
|
||||
| E_merges l ->
|
||||
(* need to explain each merge in [l] *)
|
||||
List.iter (fun (t,u) -> ps_add_obligation cc t u) l
|
||||
end
|
||||
|
||||
(* explain why [a = parent_a], where [a -> ... -> parent_a] in the
|
||||
proof forest *)
|
||||
let rec explain_along_path ps (a:node) (parent_a:node) : unit =
|
||||
if a!=parent_a then (
|
||||
match a.n_expl with
|
||||
| FL_none -> assert false
|
||||
| FL_some {next=next_a; expl=e_a_b} ->
|
||||
decompose_explain ps e_a_b;
|
||||
(* now prove [next_a = parent_a] *)
|
||||
explain_along_path ps next_a parent_a
|
||||
)
|
||||
|
||||
(* find explanation *)
|
||||
let explain_loop (cc : t) : lit list =
|
||||
while not (Vec.is_empty cc.ps_queue) do
|
||||
let a, b = Vec.pop cc.ps_queue in
|
||||
Log.debugf 5
|
||||
(fun k->k "(@[cc.explain_loop.at@ %a@ =?= %a@])" N.pp a N.pp b);
|
||||
assert (N.equal (find_ a) (find_ b));
|
||||
let c = find_common_ancestor a b in
|
||||
explain_along_path cc a c;
|
||||
explain_along_path cc b c;
|
||||
done;
|
||||
cc.ps_lits
|
||||
|
||||
(* TODO: do not use ps_lits anymore *)
|
||||
let explain_eq_n ?(init=[]) cc (n1:node) (n2:node) : lit list =
|
||||
ps_clear cc;
|
||||
cc.ps_lits <- init;
|
||||
ps_add_obligation cc n1 n2;
|
||||
explain_loop cc
|
||||
|
||||
let explain_unfold ?(init=[]) cc (e:explanation) : lit list =
|
||||
ps_clear cc;
|
||||
cc.ps_lits <- init;
|
||||
decompose_explain cc e;
|
||||
explain_loop cc
|
||||
|
||||
(* add [tag] to [n], indicating that [n] is distinct from all the other
|
||||
nodes tagged with [tag]
|
||||
precond: [n] is a representative *)
|
||||
let add_tag_n cc (n:node) (tag:int) (expl:explanation) : unit =
|
||||
assert (is_root_ n);
|
||||
if not (Util.Int_map.mem tag n.n_tags) then (
|
||||
on_backtrack cc
|
||||
(fun () -> n.n_tags <- Util.Int_map.remove tag n.n_tags);
|
||||
n.n_tags <- Util.Int_map.add tag (n,expl) n.n_tags;
|
||||
)
|
||||
|
||||
(* add a term *)
|
||||
let [@inline] rec add_term_rec_ cc t : node =
|
||||
try T_tbl.find cc.tbl t
|
||||
with Not_found -> add_new_term_ cc t
|
||||
|
||||
(* add [t] to [cc] when not present already *)
|
||||
and add_new_term_ cc (t:term) : node =
|
||||
assert (not @@ mem cc t);
|
||||
Log.debugf 15 (fun k->k "(@[cc.add-term@ %a@])" T.pp t);
|
||||
let n = N.make t in
|
||||
(* register sub-terms, add [t] to their parent list, and return the
|
||||
corresponding initial signature *)
|
||||
let sig0 = compute_sig0 cc n in
|
||||
n.n_sig0 <- sig0;
|
||||
(* remove term when we backtrack *)
|
||||
on_backtrack cc
|
||||
(fun () ->
|
||||
Log.debugf 15 (fun k->k "(@[cc.remove-term@ %a@])" T.pp t);
|
||||
T_tbl.remove cc.tbl t);
|
||||
(* add term to the table *)
|
||||
T_tbl.add cc.tbl t n;
|
||||
if CCOpt.is_some sig0 then (
|
||||
(* [n] might be merged with other equiv classes *)
|
||||
push_pending cc n;
|
||||
);
|
||||
n
|
||||
|
||||
(* compute the initial signature of the given node *)
|
||||
and compute_sig0 (self:t) (n:node) : Signature.t option =
|
||||
(* add sub-term to [cc], and register [n] to its parents *)
|
||||
let deref_sub (u:term) : node =
|
||||
let sub = add_term_rec_ self u in
|
||||
(* add [n] to [sub.root]'s parent list *)
|
||||
begin
|
||||
let sub = find_ sub in
|
||||
let old_parents = sub.n_parents in
|
||||
on_backtrack self (fun () -> sub.n_parents <- old_parents);
|
||||
sub.n_parents <- Bag.cons n sub.n_parents;
|
||||
end;
|
||||
sub
|
||||
in
|
||||
let[@inline] return x = Some x in
|
||||
match T.cc_view n.n_term with
|
||||
| Bool _ | Opaque _ -> None
|
||||
| Eq (a,b) ->
|
||||
let a = deref_sub a in
|
||||
let b = deref_sub b in
|
||||
return @@ Eq (a,b)
|
||||
| App_fun (f, args) ->
|
||||
let args = args |> Sequence.map deref_sub |> Sequence.to_list in
|
||||
if args<>[] then (
|
||||
return @@ App_fun (f, args)
|
||||
) else None
|
||||
| App_ho (f, args) ->
|
||||
let args = args |> Sequence.map deref_sub |> Sequence.to_list in
|
||||
return @@ App_ho (deref_sub f, args)
|
||||
| If (a,b,c) ->
|
||||
return @@ If (deref_sub a, deref_sub b, deref_sub c)
|
||||
|
||||
let[@inline] add_term cc t : node = add_term_rec_ cc t
|
||||
let[@inline] add_term' cc t : unit = ignore (add_term_rec_ cc t : node)
|
||||
|
||||
let set_as_lit cc (n:node) (lit:lit) : unit =
|
||||
match n.n_as_lit with
|
||||
| Some _ -> ()
|
||||
| None ->
|
||||
Log.debugf 15 (fun k->k "(@[cc.set-as-lit@ %a@ %a@])" N.pp n A.Lit.pp lit);
|
||||
on_backtrack cc (fun () -> n.n_as_lit <- None);
|
||||
n.n_as_lit <- Some lit
|
||||
|
||||
(* Checks if [ra] and [~into] have compatible normal forms and can
|
||||
be merged w.r.t. the theories.
|
||||
Side effect: also pushes sub-tasks *)
|
||||
let notify_merge cc (ra:repr) ~into:(rb:repr) (e:explanation): unit =
|
||||
assert (is_root_ rb);
|
||||
match cc.on_merge with
|
||||
| Some f -> f ra rb e
|
||||
| None -> ()
|
||||
|
||||
let[@inline] n_is_bool (self:t) n : bool =
|
||||
N.equal n (true_ self) || N.equal n (false_ self)
|
||||
|
||||
(* main CC algo: add terms from [pending] to the signature table,
|
||||
check for collisions *)
|
||||
let rec update_tasks (cc:t) (acts:sat_actions) : unit =
|
||||
while not (Vec.is_empty cc.pending && Vec.is_empty cc.combine) do
|
||||
while not @@ Vec.is_empty cc.pending do
|
||||
task_pending_ cc (Vec.pop cc.pending);
|
||||
done;
|
||||
while not @@ Vec.is_empty cc.combine do
|
||||
task_combine_ cc acts (Vec.pop cc.combine);
|
||||
done;
|
||||
done
|
||||
|
||||
and task_pending_ cc (n:node) : unit =
|
||||
N.set_field field_is_pending false n;
|
||||
(* check if some parent collided *)
|
||||
begin match n.n_sig0 with
|
||||
| None -> () (* no-op *)
|
||||
| Some (Eq (a,b)) ->
|
||||
(* if [a=b] is now true, merge [(a=b)] and [true] *)
|
||||
if same_class a b then (
|
||||
let expl = Expl.mk_merges [(a,b)] in
|
||||
push_combine cc n (true_ cc) expl
|
||||
)
|
||||
| Some s0 ->
|
||||
(* update the signature by using [find] on each sub-node *)
|
||||
let s = update_sig s0 in
|
||||
match find_signature cc s with
|
||||
| None ->
|
||||
(* add to the signature table [sig(n) --> n] *)
|
||||
add_signature cc s n
|
||||
| Some u when n == u -> ()
|
||||
| Some u ->
|
||||
(* [t1] and [t2] must be applications of the same symbol to
|
||||
arguments that are pairwise equal *)
|
||||
assert (n != u);
|
||||
let expl =
|
||||
match n.n_sig0, u.n_sig0 with
|
||||
| Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) ->
|
||||
assert (Fun.equal f1 f2);
|
||||
assert (List.length a1 = List.length a2);
|
||||
(* TODO: just use "congruence" as explanation *)
|
||||
Expl.mk_merges @@ List.combine a1 a2
|
||||
| Some (App_ho (f1, a1)), Some (App_ho (f2, a2)) ->
|
||||
assert (List.length a1 = List.length a2);
|
||||
(* TODO: just use "congruence" as explanation *)
|
||||
Expl.mk_merges @@ (f1,f2)::List.combine a1 a2
|
||||
| Some (If (a1,b1,c1)), Some (If (a2,b2,c2)) ->
|
||||
Expl.mk_merges @@ [a1,a2; b1,b2; c1,c2]
|
||||
| _
|
||||
-> assert false
|
||||
in
|
||||
push_combine cc n u expl
|
||||
(* FIXME: when to actually evaluate?
|
||||
eval_pending cc;
|
||||
*)
|
||||
end
|
||||
|
||||
and[@inline] task_combine_ cc acts = function
|
||||
| CT_merge (a,b,e_ab) -> task_merge_ cc acts a b e_ab
|
||||
| CT_distinct (l,tag,e) -> task_distinct_ cc acts l tag e
|
||||
|
||||
(* main CC algo: merge equivalence classes in [st.combine].
|
||||
@raise Exn_unsat if merge fails *)
|
||||
and task_merge_ cc acts a b e_ab : unit =
|
||||
let ra = find_ a in
|
||||
let rb = find_ b in
|
||||
if not @@ N.equal ra rb then (
|
||||
assert (is_root_ ra);
|
||||
assert (is_root_ rb);
|
||||
(* check we're not merging [true] and [false] *)
|
||||
if (N.equal ra (true_ cc) && N.equal rb (false_ cc)) ||
|
||||
(N.equal rb (true_ cc) && N.equal ra (false_ cc)) then (
|
||||
Log.debugf 5
|
||||
(fun k->k "(@[<hv>cc.merge.true_false_conflict@ @[:r1 %a@]@ @[:r2 %a@]@ :e_ab %a@])"
|
||||
N.pp ra N.pp rb Expl.pp e_ab);
|
||||
let lits = explain_unfold cc e_ab in
|
||||
let lits = explain_eq_n ~init:lits cc a ra in
|
||||
let lits = explain_eq_n ~init:lits cc b rb in
|
||||
raise_conflict cc acts lits
|
||||
);
|
||||
(* We will merge [r_from] into [r_into].
|
||||
we try to ensure that [size ra <= size rb] in general, but always
|
||||
keep values as representative *)
|
||||
let r_from, r_into =
|
||||
if n_is_bool cc ra then rb, ra
|
||||
else if n_is_bool cc rb then ra, rb
|
||||
else if size_ ra > size_ rb then rb, ra
|
||||
else ra, rb
|
||||
in
|
||||
(* TODO: instead call micro theories, including "distinct" *)
|
||||
(* update set of tags the new node cannot be equal to *)
|
||||
let new_tags =
|
||||
Util.Int_map.union
|
||||
(fun _i (n1,e1) (n2,e2) ->
|
||||
(* both maps contain same tag [_i]. conflict clause:
|
||||
[e1 & e2 & e_ab] impossible *)
|
||||
Log.debugf 5
|
||||
(fun k->k "(@[<hv>cc.merge.distinct_conflict@ :tag %d@ \
|
||||
@[:r1 %a@ :e1 %a@]@ @[:r2 %a@ :e2 %a@]@ :e_ab %a@])"
|
||||
_i N.pp n1 Expl.pp e1
|
||||
N.pp n2 Expl.pp e2 Expl.pp e_ab);
|
||||
let lits = explain_unfold cc e1 in
|
||||
let lits = explain_unfold ~init:lits cc e2 in
|
||||
let lits = explain_unfold ~init:lits cc e_ab in
|
||||
let lits = explain_eq_n ~init:lits cc a n1 in
|
||||
let lits = explain_eq_n ~init:lits cc b n2 in
|
||||
raise_conflict cc acts lits)
|
||||
ra.n_tags rb.n_tags
|
||||
in
|
||||
(* when merging terms with [true] or [false], possibly propagate them to SAT *)
|
||||
let merge_bool r1 t1 r2 t2 =
|
||||
if N.equal r1 (true_ cc) then (
|
||||
propagate_bools cc acts r2 t2 r1 t1 e_ab true
|
||||
) else if N.equal r1 (false_ cc) then (
|
||||
propagate_bools cc acts r2 t2 r1 t1 e_ab false
|
||||
)
|
||||
in
|
||||
merge_bool ra a rb b;
|
||||
merge_bool rb b ra a;
|
||||
(* perform [union r_from r_into] *)
|
||||
Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into);
|
||||
(* TODO: only iterate on parents of [rb] *)
|
||||
(* TODO: [ra.parents <- ra.parent ++ rb.parents] *)
|
||||
begin
|
||||
(* for each node in [r_from]'s class:
|
||||
- make it point to [r_into]
|
||||
- push it into [st.pending] *)
|
||||
iter_class_ r_from
|
||||
(fun u ->
|
||||
assert (u.n_root == r_from);
|
||||
on_backtrack cc (fun () -> u.n_root <- r_from);
|
||||
u.n_root <- r_into;
|
||||
Bag.to_seq u.n_parents
|
||||
(fun parent -> push_pending cc parent));
|
||||
(* now merge the classes *)
|
||||
let r_into_old_tags = r_into.n_tags in
|
||||
let r_into_old_next = r_into.n_next in
|
||||
let r_from_old_next = r_from.n_next in
|
||||
on_backtrack cc
|
||||
(fun () ->
|
||||
Log.debugf 15
|
||||
(fun k->k "(@[cc.undo_merge@ :from %a :into %a@])"
|
||||
N.pp r_from N.pp r_into);
|
||||
r_into.n_next <- r_into_old_next;
|
||||
r_from.n_next <- r_from_old_next;
|
||||
r_into.n_tags <- r_into_old_tags);
|
||||
r_into.n_tags <- new_tags;
|
||||
(* swap [into.next] and [from.next], merging the classes *)
|
||||
r_into.n_next <- r_from_old_next;
|
||||
r_from.n_next <- r_into_old_next;
|
||||
end;
|
||||
(* update explanations (a -> b), arbitrarily.
|
||||
Note that here we merge the classes by adding a bridge between [a]
|
||||
and [b], not their roots. *)
|
||||
begin
|
||||
reroot_expl cc a;
|
||||
assert (a.n_expl = FL_none);
|
||||
(* on backtracking, link may be inverted, but we delete the one
|
||||
that bridges between [a] and [b] *)
|
||||
on_backtrack cc
|
||||
(fun () -> match a.n_expl, b.n_expl with
|
||||
| FL_some e, _ when N.equal e.next b -> a.n_expl <- FL_none
|
||||
| _, FL_some e when N.equal e.next a -> b.n_expl <- FL_none
|
||||
| _ -> assert false);
|
||||
a.n_expl <- FL_some {next=b; expl=e_ab};
|
||||
end;
|
||||
(* notify listeners of the merge *)
|
||||
notify_merge cc r_from ~into:r_into e_ab;
|
||||
)
|
||||
|
||||
and task_distinct_ cc acts (l:node list) tag expl : unit =
|
||||
let l = List.map (fun n -> n, find_ n) l in
|
||||
let coll =
|
||||
Sequence.diagonal_l l
|
||||
|> Sequence.find_pred (fun ((_,r1),(_,r2)) -> N.equal r1 r2)
|
||||
in
|
||||
begin match coll with
|
||||
| Some ((n1,_r1),(n2,_r2)) ->
|
||||
(* two classes are already equal *)
|
||||
Log.debugf 5
|
||||
(fun k->k "(@[cc.distinct.conflict@ %a = %a@ :expl %a@])" N.pp n1 N.pp
|
||||
n2 Expl.pp expl);
|
||||
let lits = explain_unfold cc expl in
|
||||
raise_conflict cc acts lits
|
||||
| None ->
|
||||
(* put a tag on all equivalence classes, that will make their merge fail *)
|
||||
List.iter (fun (_,n) -> add_tag_n cc n tag expl) l
|
||||
end
|
||||
|
||||
(* we are merging [r1] with [r2==Bool(sign)], so propagate each term [u1]
|
||||
in the equiv class of [r1] that is a known literal back to the SAT solver
|
||||
and which is not the one initially merged.
|
||||
We can explain the propagation with [u1 = t1 =e= t2 = r2==bool] *)
|
||||
and propagate_bools cc acts r1 t1 r2 t2 (e_12:explanation) sign : unit =
|
||||
(* explanation for [t1 =e= t2 = r2] *)
|
||||
let half_expl = lazy (
|
||||
let expl = explain_unfold cc e_12 in
|
||||
explain_eq_n ~init:expl cc r2 t2
|
||||
) in
|
||||
iter_class_ r1
|
||||
(fun u1 ->
|
||||
(* propagate if:
|
||||
- [u1] is a proper literal
|
||||
- [t2 != r2], because that can only happen
|
||||
after an explicit merge (no way to obtain that by propagation)
|
||||
*)
|
||||
match N.as_lit u1 with
|
||||
| Some lit when not (N.equal r2 t2) ->
|
||||
let lit = if sign then lit else A.Lit.neg lit in (* apply sign *)
|
||||
Log.debugf 5 (fun k->k "(@[cc.bool_propagate@ %a@])" A.Lit.pp lit);
|
||||
(* complete explanation with the [u1=t1] chunk *)
|
||||
let expl = explain_eq_n ~init:(Lazy.force half_expl) cc u1 t1 in
|
||||
let reason = Msat.Consequence (expl, A.Proof.default) in
|
||||
acts.Msat.acts_propagate lit reason
|
||||
| _ -> ())
|
||||
|
||||
let check_invariants_ (cc:t) =
|
||||
Log.debug 5 "(cc.check-invariants)";
|
||||
Log.debugf 15 (fun k-> k "%a" pp_full cc);
|
||||
assert (T.equal (T.bool cc.tst true) (true_ cc).n_term);
|
||||
assert (T.equal (T.bool cc.tst false) (false_ cc).n_term);
|
||||
assert (not @@ same_class (true_ cc) (false_ cc));
|
||||
assert (Vec.is_empty cc.combine);
|
||||
assert (Vec.is_empty cc.pending);
|
||||
(* check that subterms are internalized *)
|
||||
T_tbl.iter
|
||||
(fun t n ->
|
||||
assert (T.equal t n.n_term);
|
||||
assert (not @@ N.get_field field_is_pending n);
|
||||
assert (N.equal n.n_root n.n_next.n_root);
|
||||
(* check proper signature.
|
||||
note that some signatures in the sig table can be obsolete (they
|
||||
were not removed) but there must be a valid, up-to-date signature for
|
||||
each term *)
|
||||
begin match CCOpt.map update_sig n.n_sig0 with
|
||||
| None -> ()
|
||||
| Some s ->
|
||||
Log.debugf 15 (fun k->k "(@[cc.check-sig@ %a@ :sig %a@])" T.pp t Signature.pp s);
|
||||
(* add, but only if not present already *)
|
||||
begin match Sig_tbl.find cc.signatures_tbl s with
|
||||
| exception Not_found -> assert false
|
||||
| repr_s -> assert (same_class n repr_s)
|
||||
end
|
||||
end;
|
||||
)
|
||||
cc.tbl;
|
||||
()
|
||||
|
||||
let[@inline] check_invariants (cc:t) : unit =
|
||||
if Util._CHECK_INVARIANTS then check_invariants_ cc
|
||||
|
||||
let add_seq cc seq =
|
||||
seq (fun t -> ignore @@ add_term_rec_ cc t);
|
||||
()
|
||||
|
||||
let[@inline] push_level (self:t) : unit =
|
||||
Backtrack_stack.push_level self.undo
|
||||
|
||||
let pop_levels (self:t) n : unit =
|
||||
Vec.iter (N.set_field field_is_pending false) self.pending;
|
||||
Vec.clear self.pending;
|
||||
Vec.clear self.combine;
|
||||
Log.debugf 15
|
||||
(fun k->k "(@[cc.pop-levels %d@ :n-lvls %d@])" n (Backtrack_stack.n_levels self.undo));
|
||||
Backtrack_stack.pop_levels self.undo n ~f:(fun f -> f());
|
||||
()
|
||||
|
||||
(* assert that this boolean literal holds.
|
||||
if a lit is [= a b], merge [a] and [b];
|
||||
if it's [distinct a1…an], make them distinct, etc. etc. *)
|
||||
let assert_lit cc lit : unit =
|
||||
let t = A.Lit.term lit in
|
||||
Log.debugf 5 (fun k->k "(@[cc.assert_lit@ %a@])" A.Lit.pp lit);
|
||||
let sign = A.Lit.sign lit in
|
||||
begin match T.cc_view t with
|
||||
| Eq (a,b) when sign ->
|
||||
(* merge [a] and [b] *)
|
||||
let a = add_term cc a in
|
||||
let b = add_term cc b in
|
||||
push_combine cc a b (Expl.mk_lit lit)
|
||||
| _ ->
|
||||
(* equate t and true/false *)
|
||||
let rhs = if sign then true_ cc else false_ cc in
|
||||
let n = add_term cc t in
|
||||
(* TODO: ensure that this is O(1).
|
||||
basically, just have [n] point to true/false and thus acquire
|
||||
the corresponding value, so its superterms (like [ite]) can evaluate
|
||||
properly *)
|
||||
push_combine cc n rhs (Expl.mk_lit lit)
|
||||
end
|
||||
|
||||
let[@inline] assert_lits cc lits : unit =
|
||||
Sequence.iter (assert_lit cc) lits
|
||||
|
||||
let assert_eq cc t1 t2 (e:lit list) : unit =
|
||||
let expl = Expl.mk_lits e in
|
||||
let n1 = add_term cc t1 in
|
||||
let n2 = add_term cc t2 in
|
||||
push_combine cc n1 n2 expl
|
||||
|
||||
let assert_distinct cc (l:term list) ~neq (lit:lit) : unit =
|
||||
assert (match l with[] | [_] -> false | _ -> true);
|
||||
assert false
|
||||
(* FIXME
|
||||
let tag = Term.id neq in
|
||||
Log.debugf 5
|
||||
(fun k->k "(@[cc.assert_distinct@ (@[%a@])@ :tag %d@])" (Util.pp_list Term.pp) l tag);
|
||||
let l = List.map (add cc) l in
|
||||
Vec.push cc.combine @@ CT_distinct (l, tag, Expl.lit lit)
|
||||
*)
|
||||
|
||||
let create ?on_merge ?(size=`Big) (tst:term_state) : t =
|
||||
let size = match size with `Small -> 128 | `Big -> 2048 in
|
||||
let rec cc = {
|
||||
tst;
|
||||
tbl = T_tbl.create size;
|
||||
signatures_tbl = Sig_tbl.create size;
|
||||
on_merge;
|
||||
pending=Vec.create();
|
||||
combine=Vec.create();
|
||||
ps_lits=[];
|
||||
undo=Backtrack_stack.create();
|
||||
ps_queue=Vec.create();
|
||||
true_;
|
||||
false_;
|
||||
} and true_ = lazy (
|
||||
add_term cc (T.bool tst true)
|
||||
) and false_ = lazy (
|
||||
add_term cc (T.bool tst false)
|
||||
)
|
||||
in
|
||||
ignore (Lazy.force true_ : node);
|
||||
ignore (Lazy.force false_ : node);
|
||||
cc
|
||||
|
||||
let[@inline] find_t cc t : repr =
|
||||
let n = T_tbl.find cc.tbl t in
|
||||
find_ n
|
||||
|
||||
let[@inline] check cc acts : unit =
|
||||
Log.debug 5 "(cc.check)";
|
||||
update_tasks cc acts
|
||||
|
||||
(* model: map each uninterpreted equiv class to some ID *)
|
||||
let mk_model (cc:t) (m:A.Model.t) : A.Model.t =
|
||||
let module Model = A.Model in
|
||||
let module Value = A.Value in
|
||||
Log.debugf 15 (fun k->k "(@[cc.mk-model@ %a@])" pp_full cc);
|
||||
let t_tbl = N_tbl.create 32 in
|
||||
(* populate [repr -> value] table *)
|
||||
T_tbl.values cc.tbl
|
||||
(fun r ->
|
||||
if is_root_ r then (
|
||||
(* find a value in the class, if any *)
|
||||
let v =
|
||||
iter_class_ r
|
||||
|> Sequence.find_map (fun n -> Model.eval m n.n_term)
|
||||
in
|
||||
let v = match v with
|
||||
| Some v -> v
|
||||
| None ->
|
||||
if same_class r (true_ cc) then Value.true_
|
||||
else if same_class r (false_ cc) then Value.false_
|
||||
else Value.fresh r.n_term
|
||||
in
|
||||
N_tbl.add t_tbl r v
|
||||
));
|
||||
(* now map every term to its representative's value *)
|
||||
let pairs =
|
||||
T_tbl.values cc.tbl
|
||||
|> Sequence.map
|
||||
(fun n ->
|
||||
let r = find_ n in
|
||||
let v =
|
||||
try N_tbl.find t_tbl r
|
||||
with Not_found ->
|
||||
Error.errorf "didn't allocate a value for repr %a" N.pp r
|
||||
in
|
||||
n.n_term, v)
|
||||
in
|
||||
let m = Sequence.fold (fun m (t,v) -> Model.add t v m) m pairs in
|
||||
Log.debugf 5 (fun k->k "(@[cc.model@ %a@])" Model.pp m);
|
||||
m
|
||||
end
|
||||
14
src/cc/Congruence_closure.mli
Normal file
14
src/cc/Congruence_closure.mli
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
(** {2 Congruence Closure} *)
|
||||
|
||||
module type ARG = Congruence_closure_intf.ARG
|
||||
module type S = Congruence_closure_intf.S
|
||||
|
||||
type payload = Congruence_closure_intf.payload = ..
|
||||
|
||||
module Make(A: ARG)
|
||||
: S with type term = A.Term.t
|
||||
and type lit = A.Lit.t
|
||||
and type fun_ = A.Fun.t
|
||||
and type term_state = A.Term.state
|
||||
and type proof = A.Proof.t
|
||||
and type model = A.Model.t
|
||||
136
src/cc/Congruence_closure_intf.ml
Normal file
136
src/cc/Congruence_closure_intf.ml
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
|
||||
module type ARG = CC_types.FULL
|
||||
|
||||
(** Theory-extensible payloads in the equivalence classes *)
|
||||
type payload = ..
|
||||
|
||||
module type S = sig
|
||||
type term_state
|
||||
type term
|
||||
type fun_
|
||||
type lit
|
||||
type proof
|
||||
type model
|
||||
|
||||
(** Actions available to the theory *)
|
||||
type sat_actions = (Msat.void, lit, Msat.void, proof) Msat.acts
|
||||
|
||||
type t
|
||||
(** Global state of the congruence closure *)
|
||||
|
||||
|
||||
(** An equivalence class is a set of terms that are currently equal
|
||||
in the partial model built by the solver.
|
||||
The class is represented by a collection of nodes, one of which is
|
||||
distinguished and is called the "representative".
|
||||
|
||||
All information pertaining to the whole equivalence class is stored
|
||||
in this representative's node.
|
||||
|
||||
When two classes become equal (are "merged"), one of the two
|
||||
representatives is picked as the representative of the new class.
|
||||
The new class contains the union of the two old classes' nodes.
|
||||
|
||||
We also allow theories to store additional information in the
|
||||
representative. This information can be used when two classes are
|
||||
merged, to detect conflicts and solve equations à la Shostak.
|
||||
*)
|
||||
module N : sig
|
||||
type t
|
||||
|
||||
val term : t -> term
|
||||
val equal : t -> t -> bool
|
||||
val hash : t -> int
|
||||
val pp : t Fmt.printer
|
||||
|
||||
type nonrec payload = payload = ..
|
||||
|
||||
val payload_find: f:(payload -> 'a option) -> t -> 'a option
|
||||
|
||||
val payload_pred: f:(payload -> bool) -> t -> bool
|
||||
|
||||
val set_payload : ?can_erase:(payload -> bool) -> t -> payload -> unit
|
||||
(** Add given payload
|
||||
@param can_erase if provided, checks whether an existing value
|
||||
is to be replaced instead of adding a new entry *)
|
||||
end
|
||||
|
||||
module Expl : sig
|
||||
type t
|
||||
val pp : t Fmt.printer
|
||||
end
|
||||
|
||||
type node = N.t
|
||||
(** A node of the congruence closure *)
|
||||
|
||||
type repr = N.t
|
||||
(** Node that is currently a representative *)
|
||||
|
||||
type explanation = Expl.t
|
||||
|
||||
type conflict = lit list
|
||||
|
||||
(* TODO micro theories as parameters *)
|
||||
val create :
|
||||
?on_merge:(repr -> repr -> explanation -> unit) ->
|
||||
?size:[`Small | `Big] ->
|
||||
term_state ->
|
||||
t
|
||||
(** Create a new congruence closure.
|
||||
@param on_merge callback to be called on every merge
|
||||
*)
|
||||
|
||||
val find : t -> node -> repr
|
||||
(** Current representative *)
|
||||
|
||||
val add_term : t -> term -> node
|
||||
(** Add the term to the congruence closure, if not present already.
|
||||
Will be backtracked. *)
|
||||
|
||||
val set_as_lit : t -> N.t -> lit -> unit
|
||||
(** map the given node to a literal. *)
|
||||
|
||||
val add_term' : t -> term -> unit
|
||||
(** Same as {!add_term} but ignore the result *)
|
||||
|
||||
val find_t : t -> term -> repr
|
||||
(** Current representative of the term.
|
||||
@raise Not_found if the term is not already {!add}-ed. *)
|
||||
|
||||
val add_seq : t -> term Sequence.t -> unit
|
||||
(** Add a sequence of terms to the congruence closure *)
|
||||
|
||||
val all_classes : t -> repr Sequence.t
|
||||
(** All current classes *)
|
||||
|
||||
val assert_lit : t -> lit -> unit
|
||||
(** Given a literal, assume it in the congruence closure and propagate
|
||||
its consequences. Will be backtracked. *)
|
||||
|
||||
val assert_lits : t -> lit Sequence.t -> unit
|
||||
|
||||
val assert_eq : t -> term -> term -> lit list -> unit
|
||||
(** merge the given terms with some explanations *)
|
||||
|
||||
val assert_distinct : t -> term list -> neq:term -> lit -> unit
|
||||
(** [assert_distinct l ~neq:u e] asserts all elements of [l] are distinct
|
||||
because [lit] is true
|
||||
precond: [u = distinct l] *)
|
||||
|
||||
val check : t -> sat_actions -> unit
|
||||
(** Perform all pending operations done via {!assert_eq}, {!assert_lit}, etc.
|
||||
Will use the [sat_actions] to propagate literals, declare conflicts, etc. *)
|
||||
|
||||
val push_level : t -> unit
|
||||
|
||||
val pop_levels : t -> int -> unit
|
||||
|
||||
val mk_model : t -> model -> model
|
||||
(** Enrich a model by mapping terms to their representative's value,
|
||||
if any. Otherwise map the representative to a fresh value *)
|
||||
|
||||
(**/**)
|
||||
val check_invariants : t -> unit
|
||||
val pp_full : t Fmt.printer
|
||||
(**/**)
|
||||
end
|
||||
|
|
@ -1,23 +1,34 @@
|
|||
|
||||
module H = CCHash
|
||||
|
||||
type ('f, 't, 'ts) view = ('f, 't, 'ts) Mini_cc_intf.view =
|
||||
| Bool of bool
|
||||
| App of 'f * 'ts
|
||||
| If of 't * 't * 't
|
||||
|
||||
type res = Mini_cc_intf.res =
|
||||
type res =
|
||||
| Sat
|
||||
| Unsat
|
||||
|
||||
module type ARG = Mini_cc_intf.ARG
|
||||
module type S = Mini_cc_intf.S
|
||||
module type TERM = CC_types.TERM
|
||||
|
||||
module type S = sig
|
||||
type term
|
||||
type fun_
|
||||
type term_state
|
||||
|
||||
type t
|
||||
|
||||
val create : term_state -> t
|
||||
|
||||
val add_lit : t -> term -> bool -> unit
|
||||
val distinct : t -> term list -> unit
|
||||
|
||||
val check : t -> res
|
||||
end
|
||||
|
||||
|
||||
module Make(A: TERM) = struct
|
||||
open CC_types
|
||||
|
||||
module Make(A: ARG) = struct
|
||||
module Fun = A.Fun
|
||||
module T = A.Term
|
||||
type fun_ = A.Fun.t
|
||||
type term = T.t
|
||||
type term_state = A.Term.state
|
||||
|
||||
module T_tbl = CCHashtbl.Make(T)
|
||||
|
||||
|
|
@ -65,49 +76,77 @@ module Make(A: ARG) = struct
|
|||
let equal (s1:t) s2 : bool =
|
||||
match s1, s2 with
|
||||
| Bool b1, Bool b2 -> b1=b2
|
||||
| App (f1,[]), App (f2,[]) -> Fun.equal f1 f2
|
||||
| App (f1,l1), App (f2,l2) ->
|
||||
| App_fun (f1,[]), App_fun (f2,[]) -> Fun.equal f1 f2
|
||||
| App_fun (f1,l1), App_fun (f2,l2) ->
|
||||
Fun.equal f1 f2 && CCList.equal Node.equal l1 l2
|
||||
| App_ho (f1,l1), App_ho (f2,l2) ->
|
||||
Node.equal f1 f2 && CCList.equal Node.equal l1 l2
|
||||
| If (a1,b1,c1), If (a2,b2,c2) ->
|
||||
Node.equal a1 a2 && Node.equal b1 b2 && Node.equal c1 c2
|
||||
| Bool _, _ | App _, _ | If _, _
|
||||
| Eq (a1,b1), Eq (a2,b2) ->
|
||||
Node.equal a1 a2 && Node.equal b1 b2
|
||||
| Opaque u1, Opaque u2 -> Node.equal u1 u2
|
||||
| Bool _, _ | App_fun _, _ | App_ho _, _ | If _, _
|
||||
| Eq _, _ | Opaque _, _
|
||||
-> false
|
||||
|
||||
let hash (s:t) : int =
|
||||
let module H = CCHash in
|
||||
match s with
|
||||
| Bool b -> H.combine2 10 (H.bool b)
|
||||
| App (f, l) -> H.combine3 20 (Fun.hash f) (H.list Node.hash l)
|
||||
| If (a,b,c) -> H.combine4 30 (Node.hash a)(Node.hash b)(Node.hash c)
|
||||
| App_fun (f, l) -> H.combine3 20 (Fun.hash f) (H.list Node.hash l)
|
||||
| App_ho (f, l) -> H.combine3 30 (Node.hash f) (H.list Node.hash l)
|
||||
| Eq (a,b) -> H.combine3 40 (Node.hash a) (Node.hash b)
|
||||
| Opaque u -> H.combine2 50 (Node.hash u)
|
||||
| If (a,b,c) -> H.combine4 60 (Node.hash a)(Node.hash b)(Node.hash c)
|
||||
|
||||
let pp out = function
|
||||
| Bool b -> Fmt.bool out b
|
||||
| App (f, []) -> Fun.pp out f
|
||||
| App (f, l) -> Fmt.fprintf out "(@[%a@ %a@])" Fun.pp f (Util.pp_list Node.pp) l
|
||||
| App_fun (f, []) -> Fun.pp out f
|
||||
| App_fun (f, l) -> Fmt.fprintf out "(@[%a@ %a@])" Fun.pp f (Util.pp_list Node.pp) l
|
||||
| App_ho (f, []) -> Node.pp out f
|
||||
| App_ho (f, l) -> Fmt.fprintf out "(@[%a@ %a@])" Node.pp f (Util.pp_list Node.pp) l
|
||||
| Opaque t -> Node.pp out t
|
||||
| Eq (a,b) -> Fmt.fprintf out "(@[=@ %a@ %a@])" Node.pp a Node.pp b
|
||||
| If (a,b,c) -> Fmt.fprintf out "(@[ite@ %a@ %a@ %a@])" Node.pp a Node.pp b Node.pp c
|
||||
end
|
||||
|
||||
module Sig_tbl = CCHashtbl.Make(Signature)
|
||||
|
||||
type t = {
|
||||
mutable ok: bool; (* unsat? *)
|
||||
tbl: node T_tbl.t;
|
||||
sig_tbl: node Sig_tbl.t;
|
||||
combine: (node * node) Vec.t;
|
||||
pending: node Vec.t; (* refresh signature *)
|
||||
distinct: node list ref Vec.t; (* disjoint sets *)
|
||||
true_: node;
|
||||
false_: node;
|
||||
}
|
||||
|
||||
let create() : t =
|
||||
{ tbl= T_tbl.create 128;
|
||||
let create tst : t =
|
||||
let true_ = T.bool tst true in
|
||||
let false_ = T.bool tst false in
|
||||
let self = {
|
||||
ok=true;
|
||||
tbl= T_tbl.create 128;
|
||||
sig_tbl=Sig_tbl.create 128;
|
||||
combine=Vec.create();
|
||||
pending=Vec.create();
|
||||
distinct=Vec.create();
|
||||
}
|
||||
true_=Node.make true_;
|
||||
false_=Node.make false_;
|
||||
} in
|
||||
T_tbl.add self.tbl true_ self.true_;
|
||||
T_tbl.add self.tbl false_ self.false_;
|
||||
self
|
||||
|
||||
let sub_ t k : unit =
|
||||
match T.view t with
|
||||
| Bool _ -> ()
|
||||
| App (_, args) -> args k
|
||||
match T.cc_view t with
|
||||
| Bool _ | Opaque _ -> ()
|
||||
| App_fun (_, args) -> args k
|
||||
| App_ho (f, args) -> k f; args k
|
||||
| Eq (a,b) -> k a; k b
|
||||
| If(a,b,c) -> k a; k b; k c
|
||||
|
||||
let rec add_t (self:t) (t:term) : node =
|
||||
|
|
@ -152,8 +191,37 @@ module Make(A: ARG) = struct
|
|||
if has_dups !r then raise_notrace E_unsat)
|
||||
self.distinct
|
||||
|
||||
let compute_sig (self:t) (n:node) : Signature.t option =
|
||||
let[@inline] return x = Some x in
|
||||
match T.cc_view n.n_t with
|
||||
| Bool _ | Opaque _ -> None
|
||||
| Eq (a,b) ->
|
||||
let a = find_t_ self a in
|
||||
let b = find_t_ self b in
|
||||
return @@ Eq (a,b)
|
||||
| App_fun (f, args) ->
|
||||
let args = args |> Sequence.map (find_t_ self) |> Sequence.to_list in
|
||||
if args<>[] then (
|
||||
return @@ App_fun (f, args)
|
||||
) else None
|
||||
| App_ho (f, args) ->
|
||||
let args = args |> Sequence.map (find_t_ self) |> Sequence.to_list in
|
||||
return @@ App_ho (find_t_ self f, args)
|
||||
| If (a,b,c) ->
|
||||
return @@ If(find_t_ self a, find_t_ self b, find_t_ self c)
|
||||
|
||||
let update_sig_ (self:t) (n: node) : unit =
|
||||
let aux s =
|
||||
match compute_sig self n with
|
||||
| None -> ()
|
||||
| Some (Eq (a,b)) ->
|
||||
if Node.equal a b then (
|
||||
(* reduce to [true] *)
|
||||
let n2 = self.true_ in
|
||||
Log.debugf 5
|
||||
(fun k->k "(@[minicc.congruence-by-eq@ %a@ %a@])" Node.pp n Node.pp n2);
|
||||
Vec.push self.combine (n,n2)
|
||||
)
|
||||
| Some s ->
|
||||
Log.debugf 5 (fun k->k "(@[minicc.update-sig@ %a@])" Signature.pp s);
|
||||
match Sig_tbl.find self.sig_tbl s with
|
||||
| n2 when Node.equal n n2 -> ()
|
||||
|
|
@ -164,23 +232,28 @@ module Make(A: ARG) = struct
|
|||
Vec.push self.combine (n,n2)
|
||||
| exception Not_found ->
|
||||
Sig_tbl.add self.sig_tbl s n
|
||||
in
|
||||
match T.view n.n_t with
|
||||
| Bool _ -> ()
|
||||
| App (f, args) ->
|
||||
let args = args |> Sequence.map (find_t_ self) |> Sequence.to_list in
|
||||
aux @@ App (f, args)
|
||||
| If (a,b,c) -> aux @@ If(find_t_ self a, find_t_ self b, find_t_ self c)
|
||||
|
||||
let[@inline] is_bool self n = Node.equal self.true_ n || Node.equal self.false_ n
|
||||
|
||||
(* merge the two classes *)
|
||||
let merge_ self (n1,n2) : unit =
|
||||
let n1 = find_ n1 in
|
||||
let n2 = find_ n2 in
|
||||
if not @@ Node.equal n1 n2 then (
|
||||
(* merge into largest class *)
|
||||
let n1, n2 = if Node.size n1 > Node.size n2 then n1, n2 else n2, n1 in
|
||||
(* merge into largest class, or into a boolean *)
|
||||
let n1, n2 =
|
||||
if is_bool self n1 then n1, n2
|
||||
else if is_bool self n2 then n2, n1
|
||||
else if Node.size n1 > Node.size n2 then n1, n2
|
||||
else n2, n1 in
|
||||
Log.debugf 5 (fun k->k "(@[minicc.merge@ :into %a@ %a@])" Node.pp n1 Node.pp n2);
|
||||
|
||||
if is_bool self n1 && is_bool self n2 then (
|
||||
Log.debugf 5 (fun k->k "(minicc.conflict.merge-true-false)");
|
||||
self.ok <- false;
|
||||
raise E_unsat
|
||||
);
|
||||
|
||||
List.iter (Vec.push self.pending) n2.n_parents; (* will change signature *)
|
||||
|
||||
(* merge parent lists *)
|
||||
|
|
@ -191,9 +264,13 @@ module Make(A: ARG) = struct
|
|||
Node.iter_cls n2 (fun n -> n.n_root <- n1);
|
||||
)
|
||||
|
||||
let check_ok_ self =
|
||||
if not self.ok then raise_notrace E_unsat
|
||||
|
||||
(* fixpoint of the congruence closure *)
|
||||
let fixpoint (self:t) : unit =
|
||||
while not (Vec.is_empty self.pending && Vec.is_empty self.combine) do
|
||||
check_ok_ self;
|
||||
while not @@ Vec.is_empty self.pending do
|
||||
update_sig_ self @@ Vec.pop self.pending
|
||||
done;
|
||||
|
|
@ -205,10 +282,17 @@ module Make(A: ARG) = struct
|
|||
|
||||
(* API *)
|
||||
|
||||
let merge (self:t) t1 t2 : unit =
|
||||
let n1 = add_t self t1 in
|
||||
let n2 = add_t self t2 in
|
||||
Vec.push self.combine (n1,n2)
|
||||
let add_lit (self:t) (p:T.t) (sign:bool) : unit =
|
||||
match T.cc_view p with
|
||||
| Eq (t1,t2) when sign ->
|
||||
let n1 = add_t self t1 in
|
||||
let n2 = add_t self t2 in
|
||||
Vec.push self.combine (n1,n2)
|
||||
| _ ->
|
||||
(* just merge with true/false *)
|
||||
let n = add_t self p in
|
||||
let n2 = if sign then self.true_ else self.false_ in
|
||||
Vec.push self.combine (n,n2)
|
||||
|
||||
let distinct (self:t) l =
|
||||
begin match l with
|
||||
|
|
@ -220,6 +304,8 @@ module Make(A: ARG) = struct
|
|||
|
||||
let check (self:t) : res =
|
||||
try fixpoint self; Sat
|
||||
with E_unsat -> Unsat
|
||||
with E_unsat ->
|
||||
self.ok <- false;
|
||||
Unsat
|
||||
|
||||
end
|
||||
36
src/cc/Mini_cc.mli
Normal file
36
src/cc/Mini_cc.mli
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
|
||||
(** {1 Mini congruence closure}
|
||||
|
||||
This implementation is as simple as possible, and doesn't provide
|
||||
backtracking, theories, or explanations.
|
||||
It just decides the satisfiability of a set of (dis)equations.
|
||||
*)
|
||||
|
||||
type res =
|
||||
| Sat
|
||||
| Unsat
|
||||
|
||||
module type TERM = CC_types.TERM
|
||||
|
||||
module type S = sig
|
||||
type term
|
||||
type fun_
|
||||
type term_state
|
||||
|
||||
type t
|
||||
|
||||
val create : term_state -> t
|
||||
|
||||
val add_lit : t -> term -> bool -> unit
|
||||
(** [add_lit cc p sign] asserts that [p=sign] *)
|
||||
|
||||
val distinct : t -> term list -> unit
|
||||
(** [distinct cc l] asserts that all terms in [l] are distinct *)
|
||||
|
||||
val check : t -> res
|
||||
end
|
||||
|
||||
module Make(A: TERM)
|
||||
: S with type term = A.Term.t
|
||||
and type fun_ = A.Fun.t
|
||||
and type term_state = A.Term.state
|
||||
23
src/cc/Sidekick_cc.ml
Normal file
23
src/cc/Sidekick_cc.ml
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
|
||||
type ('f, 't, 'ts) view = ('f, 't, 'ts) CC_types.view =
|
||||
| Bool of bool
|
||||
| App_fun of 'f * 'ts
|
||||
| App_ho of 't * 'ts
|
||||
| If of 't * 't * 't
|
||||
| Eq of 't * 't
|
||||
| Opaque of 't (* do not enter *)
|
||||
|
||||
type payload = Congruence_closure.payload = ..
|
||||
|
||||
module CC_types = CC_types
|
||||
|
||||
(** Parameter for the congruence closure *)
|
||||
module type TERM_LIT = CC_types.TERM_LIT
|
||||
module type FULL = CC_types.FULL
|
||||
module type S = Congruence_closure.S
|
||||
|
||||
module Mini_cc = Mini_cc
|
||||
module Congruence_closure = Congruence_closure
|
||||
|
||||
module Make = Congruence_closure.Make
|
||||
|
||||
10
src/cc/dune
Normal file
10
src/cc/dune
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
|
||||
|
||||
(library
|
||||
(name Sidekick_cc)
|
||||
(public_name sidekick.cc)
|
||||
(libraries containers containers.data msat sequence sidekick.util)
|
||||
(flags :standard -warn-error -a+8
|
||||
-color always -safe-string -short-paths -open Sidekick_util)
|
||||
(ocamlopt_flags :standard -O3 -color always
|
||||
-unbox-closures -unbox-closures-factor 20))
|
||||
18
src/smt/CC.ml
Normal file
18
src/smt/CC.ml
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
|
||||
module Arg = struct
|
||||
module Fun = Cst
|
||||
module Term = Term
|
||||
module Lit = Lit
|
||||
module Value = Value
|
||||
module Ty = Ty
|
||||
module Model = Model
|
||||
module Proof = struct
|
||||
type t = Solver_types.proof
|
||||
let pp = Solver_types.pp_proof
|
||||
let default = Solver_types.Proof_default
|
||||
end
|
||||
end
|
||||
|
||||
include Sidekick_cc.Make(Arg)
|
||||
|
||||
module Mini_cc = Sidekick_cc.Mini_cc.Make(Arg)
|
||||
13
src/smt/CC.mli
Normal file
13
src/smt/CC.mli
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
|
||||
include Sidekick_cc.S
|
||||
with type term = Term.t
|
||||
and type model = Model.t
|
||||
and type lit = Lit.t
|
||||
and type fun_ = Cst.t
|
||||
and type term_state = Term.state
|
||||
and type proof = Solver_types.proof
|
||||
|
||||
module Mini_cc : Sidekick_cc.Mini_cc.S
|
||||
with type term = Term.t
|
||||
and type fun_ = Cst.t
|
||||
and type term_state = Term.state
|
||||
|
|
@ -1,763 +0,0 @@
|
|||
|
||||
open Solver_types
|
||||
|
||||
module N = Eq_class
|
||||
|
||||
type node = N.t
|
||||
type repr = N.t
|
||||
type conflict = Theory.conflict
|
||||
|
||||
module T_arg = struct
|
||||
module Fun = Cst
|
||||
module Term = struct
|
||||
include Term
|
||||
let view = cc_view
|
||||
end
|
||||
end
|
||||
module Mini_cc = Mini_cc.Make(T_arg)
|
||||
|
||||
(** A signature is a shallow term shape where immediate subterms
|
||||
are representative *)
|
||||
module Signature = struct
|
||||
type t = node Term.view
|
||||
include Term_cell.Make_eq(N)
|
||||
end
|
||||
|
||||
module Sig_tbl = CCHashtbl.Make(Signature)
|
||||
|
||||
type explanation_thunk = explanation lazy_t
|
||||
|
||||
type combine_task =
|
||||
| CT_merge of node * node * explanation_thunk
|
||||
| CT_distinct of node list * int * explanation
|
||||
|
||||
type t = {
|
||||
tst: Term.state;
|
||||
tbl: node Term.Tbl.t;
|
||||
(* internalization [term -> node] *)
|
||||
signatures_tbl : node Sig_tbl.t;
|
||||
(* map a signature to the corresponding node in some equivalence class.
|
||||
A signature is a [term_cell] in which every immediate subterm
|
||||
that participates in the congruence/evaluation relation
|
||||
is normalized (i.e. is its own representative).
|
||||
The critical property is that all members of an equivalence class
|
||||
that have the same "shape" (including head symbol)
|
||||
have the same signature *)
|
||||
pending: node Vec.t;
|
||||
combine: combine_task Vec.t;
|
||||
undo: (unit -> unit) Backtrack_stack.t;
|
||||
on_merge: (repr -> repr -> explanation -> unit) option;
|
||||
mutable ps_lits: Lit.Set.t; (* TODO: thread it around instead? *)
|
||||
(* proof state *)
|
||||
ps_queue: (node*node) Vec.t;
|
||||
(* pairs to explain *)
|
||||
true_ : node lazy_t;
|
||||
false_ : node lazy_t;
|
||||
}
|
||||
(* TODO: an additional union-find to keep track, for each term,
|
||||
of the terms they are known to be equal to, according
|
||||
to the current explanation. That allows not to prove some equality
|
||||
several times.
|
||||
See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *)
|
||||
|
||||
let[@inline] is_root_ (n:node) : bool = n.n_root == n
|
||||
let[@inline] size_ (r:repr) = r.n_size
|
||||
let[@inline] true_ cc = Lazy.force cc.true_
|
||||
let[@inline] false_ cc = Lazy.force cc.false_
|
||||
|
||||
let[@inline] on_backtrack cc f : unit =
|
||||
Backtrack_stack.push_if_nonzero_level cc.undo f
|
||||
|
||||
(* check if [t] is in the congruence closure.
|
||||
Invariant: [in_cc t ∧ do_cc t => forall u subterm t, in_cc u] *)
|
||||
let[@inline] mem (cc:t) (t:term): bool = Term.Tbl.mem cc.tbl t
|
||||
|
||||
(* find representative, recursively *)
|
||||
let rec find_rec cc (n:node) : repr =
|
||||
if n==n.n_root then (
|
||||
n
|
||||
) else (
|
||||
(* TODO: path compression, assuming backtracking restores equiv classes
|
||||
properly *)
|
||||
let root = find_rec cc n.n_root in
|
||||
root
|
||||
)
|
||||
|
||||
(* traverse the equivalence class of [n] *)
|
||||
let iter_class_ (n:node) : node Sequence.t =
|
||||
fun yield ->
|
||||
let rec aux u =
|
||||
yield u;
|
||||
if u.n_next != n then aux u.n_next
|
||||
in
|
||||
aux n
|
||||
|
||||
(* get term that should be there *)
|
||||
let[@inline] get_ cc (t:term) : node =
|
||||
try Term.Tbl.find cc.tbl t
|
||||
with Not_found ->
|
||||
Log.debugf 1 (fun k->k "(@[<hv1>cc.error@ :missing-term %a@])" Term.pp t);
|
||||
assert false
|
||||
|
||||
(* non-recursive, inlinable function for [find] *)
|
||||
let[@inline] find st (n:node) : repr =
|
||||
if n == n.n_root then n else find_rec st n
|
||||
|
||||
let[@inline] find_tn cc (t:term) : repr = get_ cc t |> find cc
|
||||
|
||||
let[@inline] same_class cc (n1:node)(n2:node): bool =
|
||||
N.equal (find cc n1) (find cc n2)
|
||||
|
||||
(* print full state *)
|
||||
let pp_full out (cc:t) : unit =
|
||||
let pp_next out n =
|
||||
Fmt.fprintf out "@ :next %a" N.pp n.n_next in
|
||||
let pp_root out n =
|
||||
if is_root_ n then Fmt.string out " :is-root" else Fmt.fprintf out "@ :root %a" N.pp n.n_root in
|
||||
let pp_expl out n = match n.n_expl with
|
||||
| E_none -> ()
|
||||
| E_some e ->
|
||||
Fmt.fprintf out " (@[:forest %a :expl %a@])" N.pp e.next Explanation.pp e.expl
|
||||
in
|
||||
let pp_n out n =
|
||||
Fmt.fprintf out "(@[%a%a%a%a@])" Term.pp n.n_term pp_root n pp_next n pp_expl n
|
||||
and pp_sig_e out (s,n) =
|
||||
Fmt.fprintf out "(@[<1>%a@ <--> %a%a@])" Signature.pp s N.pp n pp_root n
|
||||
in
|
||||
Fmt.fprintf out
|
||||
"(@[@{<yellow>cc.state@}@ (@[<hv>:nodes@ %a@])@ (@[<hv>:sig@ %a@])@])"
|
||||
(Util.pp_seq ~sep:" " pp_n) (Term.Tbl.values cc.tbl)
|
||||
(Util.pp_seq ~sep:" " pp_sig_e) (Sig_tbl.to_seq cc.signatures_tbl)
|
||||
|
||||
(* compute signature *)
|
||||
let signature cc (t:term): Signature.t option =
|
||||
let find = find_tn cc in
|
||||
begin match Term.view t with
|
||||
| App_cst (_, a) when IArray.is_empty a -> None
|
||||
| App_cst (c, _) when not @@ Cst.do_cc c -> None (* no CC *)
|
||||
| App_cst (f, a) -> Some (App_cst (f, IArray.map find a)) (* FIXME: relevance? *)
|
||||
| Bool _ | If _
|
||||
-> None (* no congruence for these *)
|
||||
end
|
||||
|
||||
(* find whether the given (parent) term corresponds to some signature
|
||||
in [signatures_] *)
|
||||
let find_by_signature cc (t:term) : repr option =
|
||||
match signature cc t with
|
||||
| None -> None
|
||||
| Some s -> Sig_tbl.get cc.signatures_tbl s
|
||||
|
||||
let add_signature cc (n:node): unit =
|
||||
match signature cc n.n_term with
|
||||
| None -> ()
|
||||
| Some s ->
|
||||
(* add, but only if not present already *)
|
||||
begin match Sig_tbl.find cc.signatures_tbl s with
|
||||
| exception Not_found ->
|
||||
Log.debugf 15
|
||||
(fun k->k "(@[cc.add_sig@ %a@ <--> %a@])" Signature.pp s N.pp n);
|
||||
on_backtrack cc (fun () -> Sig_tbl.remove cc.signatures_tbl s);
|
||||
Sig_tbl.add cc.signatures_tbl s n;
|
||||
| r' ->
|
||||
assert (same_class cc n r');
|
||||
end
|
||||
|
||||
let push_pending cc t : unit =
|
||||
if not @@ N.get_field N.field_is_pending t then (
|
||||
Log.debugf 5 (fun k->k "(@[<hv1>cc.push_pending@ %a@])" N.pp t);
|
||||
N.set_field N.field_is_pending true t;
|
||||
Vec.push cc.pending t
|
||||
)
|
||||
|
||||
let push_combine cc t u e : unit =
|
||||
Log.debugf 5
|
||||
(fun k->k "(@[<hv1>cc.push_combine@ :t1 %a@ :t2 %a@ :expl %a@])"
|
||||
N.pp t N.pp u Explanation.pp (Lazy.force e));
|
||||
Vec.push cc.combine @@ CT_merge (t,u,e)
|
||||
|
||||
(* re-root the explanation tree of the equivalence class of [n]
|
||||
so that it points to [n].
|
||||
postcondition: [n.n_expl = None] *)
|
||||
let rec reroot_expl (cc:t) (n:node): unit =
|
||||
let old_expl = n.n_expl in
|
||||
begin match old_expl with
|
||||
| E_none -> () (* already root *)
|
||||
| E_some {next=u; expl=e_n_u} ->
|
||||
reroot_expl cc u;
|
||||
u.n_expl <- E_some {next=n; expl=e_n_u};
|
||||
n.n_expl <- E_none;
|
||||
end
|
||||
|
||||
let raise_conflict (cc:t) (acts:sat_actions) (e:conflict): _ =
|
||||
(* clear tasks queue *)
|
||||
Vec.iter (N.set_field N.field_is_pending false) cc.pending;
|
||||
Vec.clear cc.pending;
|
||||
Vec.clear cc.combine;
|
||||
let c = List.map Lit.neg e in
|
||||
acts.Msat.acts_raise_conflict c Proof_default
|
||||
|
||||
let[@inline] all_classes cc : repr Sequence.t =
|
||||
Term.Tbl.values cc.tbl
|
||||
|> Sequence.filter is_root_
|
||||
|
||||
(* TODO: use markers and lockstep iteration instead *)
|
||||
(* distance from [t] to its root in the proof forest *)
|
||||
let[@inline][@unroll 2] rec distance_to_root (n:node): int = match n.n_expl with
|
||||
| E_none -> 0
|
||||
| E_some {next=t'; _} -> 1 + distance_to_root t'
|
||||
|
||||
(* TODO: bool flag on nodes + stepwise progress + cleanup *)
|
||||
(* find the closest common ancestor of [a] and [b] in the proof forest *)
|
||||
let find_common_ancestor (a:node) (b:node) : node =
|
||||
let d_a = distance_to_root a in
|
||||
let d_b = distance_to_root b in
|
||||
(* drop [n] nodes in the path from [t] to its root *)
|
||||
let rec drop_ n t =
|
||||
if n=0 then t
|
||||
else match t.n_expl with
|
||||
| E_none -> assert false
|
||||
| E_some {next=t'; _} -> drop_ (n-1) t'
|
||||
in
|
||||
(* reduce to the problem where [a] and [b] have the same distance to root *)
|
||||
let a, b =
|
||||
if d_a > d_b then drop_ (d_a-d_b) a, b
|
||||
else if d_a < d_b then a, drop_ (d_b-d_a) b
|
||||
else a, b
|
||||
in
|
||||
(* traverse stepwise until a==b *)
|
||||
let rec aux_same_dist a b =
|
||||
if a==b then a
|
||||
else match a.n_expl, b.n_expl with
|
||||
| E_none, _ | _, E_none -> assert false
|
||||
| E_some {next=a'; _}, E_some {next=b'; _} -> aux_same_dist a' b'
|
||||
in
|
||||
aux_same_dist a b
|
||||
|
||||
let[@inline] ps_add_obligation (cc:t) a b = Vec.push cc.ps_queue (a,b)
|
||||
let[@inline] ps_add_lit ps l = ps.ps_lits <- Lit.Set.add l ps.ps_lits
|
||||
|
||||
let ps_clear (cc:t) =
|
||||
cc.ps_lits <- Lit.Set.empty;
|
||||
Vec.clear cc.ps_queue;
|
||||
()
|
||||
|
||||
let decompose_explain cc (e:explanation): unit =
|
||||
Log.debugf 5 (fun k->k "(@[cc.decompose_expl@ %a@])" Explanation.pp e);
|
||||
begin match e with
|
||||
| E_reduction -> ()
|
||||
| E_lit lit -> ps_add_lit cc lit
|
||||
| E_lits l -> List.iter (ps_add_lit cc) l
|
||||
| E_merges l ->
|
||||
(* need to explain each merge in [l] *)
|
||||
IArray.iter (fun (t,u) -> ps_add_obligation cc t u) l
|
||||
end
|
||||
|
||||
(* explain why [a = parent_a], where [a -> ... -> parent_a] in the
|
||||
proof forest *)
|
||||
let rec explain_along_path ps (a:node) (parent_a:node) : unit =
|
||||
if a!=parent_a then (
|
||||
match a.n_expl with
|
||||
| E_none -> assert false
|
||||
| E_some {next=next_a; expl=e_a_b} ->
|
||||
decompose_explain ps e_a_b;
|
||||
(* now prove [next_a = parent_a] *)
|
||||
explain_along_path ps next_a parent_a
|
||||
)
|
||||
|
||||
(* find explanation *)
|
||||
let explain_loop (cc : t) : Lit.Set.t =
|
||||
while not (Vec.is_empty cc.ps_queue) do
|
||||
let a, b = Vec.pop cc.ps_queue in
|
||||
Log.debugf 5
|
||||
(fun k->k "(@[cc.explain_loop.at@ %a@ =?= %a@])" N.pp a N.pp b);
|
||||
assert (N.equal (find cc a) (find cc b));
|
||||
let c = find_common_ancestor a b in
|
||||
explain_along_path cc a c;
|
||||
explain_along_path cc b c;
|
||||
done;
|
||||
cc.ps_lits
|
||||
|
||||
(* TODO: do not use ps_lits anymore *)
|
||||
let explain_eq_n ?(init=Lit.Set.empty) cc (n1:node) (n2:node) : Lit.Set.t =
|
||||
ps_clear cc;
|
||||
cc.ps_lits <- init;
|
||||
ps_add_obligation cc n1 n2;
|
||||
explain_loop cc
|
||||
|
||||
let explain_unfold ?(init=Lit.Set.empty) cc (e:explanation) : Lit.Set.t =
|
||||
ps_clear cc;
|
||||
cc.ps_lits <- init;
|
||||
decompose_explain cc e;
|
||||
explain_loop cc
|
||||
|
||||
(* add [tag] to [n], indicating that [n] is distinct from all the other
|
||||
nodes tagged with [tag]
|
||||
precond: [n] is a representative *)
|
||||
let add_tag_n cc (n:node) (tag:int) (expl:explanation) : unit =
|
||||
assert (is_root_ n);
|
||||
if not (Util.Int_map.mem tag n.n_tags) then (
|
||||
on_backtrack cc
|
||||
(fun () -> n.n_tags <- Util.Int_map.remove tag n.n_tags);
|
||||
n.n_tags <- Util.Int_map.add tag (n,expl) n.n_tags;
|
||||
)
|
||||
|
||||
(* TODO: payload for set of tags *)
|
||||
(* TODO: payload for mapping an equiv class to a set of literals, for bool prop *)
|
||||
|
||||
let relevant_subterms (t:Term.t) : Term.t Sequence.t =
|
||||
fun yield ->
|
||||
match t.term_view with
|
||||
| App_cst (c, a) when Cst.do_cc c -> IArray.iter yield a
|
||||
| Bool _ | App_cst _ -> ()
|
||||
| If (a,b,c) ->
|
||||
(* TODO: relevancy? only [a] needs be decided for now *)
|
||||
yield a;
|
||||
yield b;
|
||||
yield c
|
||||
|
||||
(* Checks if [ra] and [~into] have compatible normal forms and can
|
||||
be merged w.r.t. the theories.
|
||||
Side effect: also pushes sub-tasks *)
|
||||
let notify_merge cc (ra:repr) ~into:(rb:repr) (e:explanation): unit =
|
||||
assert (is_root_ rb);
|
||||
match cc.on_merge with
|
||||
| Some f -> f ra rb e
|
||||
| None -> ()
|
||||
|
||||
(* main CC algo: add terms from [pending] to the signature table,
|
||||
check for collisions *)
|
||||
let rec update_tasks (cc:t) (acts:sat_actions) : unit =
|
||||
while not (Vec.is_empty cc.pending && Vec.is_empty cc.combine) do
|
||||
Vec.iter (task_pending_ cc) cc.pending;
|
||||
Vec.clear cc.pending;
|
||||
Vec.iter (task_combine_ cc acts) cc.combine;
|
||||
Vec.clear cc.combine;
|
||||
done
|
||||
|
||||
and task_pending_ cc n =
|
||||
N.set_field N.field_is_pending false n;
|
||||
(* check if some parent collided *)
|
||||
begin match find_by_signature cc n.n_term with
|
||||
| None ->
|
||||
(* add to the signature table [sig(n) --> n] *)
|
||||
add_signature cc n
|
||||
| Some u when n == u -> ()
|
||||
| Some u ->
|
||||
(* [t1] and [t2] must be applications of the same symbol to
|
||||
arguments that are pairwise equal *)
|
||||
assert (n != u);
|
||||
let expl = lazy (
|
||||
match n.n_term.term_view, u.n_term.term_view with
|
||||
| App_cst (f1, a1), App_cst (f2, a2) ->
|
||||
assert (Cst.equal f1 f2);
|
||||
assert (IArray.length a1 = IArray.length a2);
|
||||
(* TODO: just use "congruence" as explanation *)
|
||||
Explanation.mk_merges @@ IArray.map2 (fun u1 u2 -> add_term_rec_ cc u1, add_term_rec_ cc u2) a1 a2
|
||||
| If _, _ | App_cst _, _ | Bool _, _
|
||||
-> assert false
|
||||
) in
|
||||
push_combine cc n u expl
|
||||
end;
|
||||
(* TODO: evaluate [(= t u) := true] when [find t==find u] *)
|
||||
(* FIXME: when to actually evaluate?
|
||||
eval_pending cc;
|
||||
*)
|
||||
()
|
||||
|
||||
and[@inline] task_combine_ cc acts = function
|
||||
| CT_merge (a,b,e_ab) -> task_merge_ cc acts a b e_ab
|
||||
| CT_distinct (l,tag,e) -> task_distinct_ cc acts l tag e
|
||||
|
||||
(* main CC algo: merge equivalence classes in [st.combine].
|
||||
@raise Exn_unsat if merge fails *)
|
||||
and task_merge_ cc acts a b e_ab : unit =
|
||||
let ra = find cc a in
|
||||
let rb = find cc b in
|
||||
if not @@ N.equal ra rb then (
|
||||
assert (is_root_ ra);
|
||||
assert (is_root_ rb);
|
||||
let lazy e_ab = e_ab in
|
||||
(* We will merge [r_from] into [r_into].
|
||||
we try to ensure that [size ra <= size rb] in general, but always
|
||||
keep values as representative *)
|
||||
let r_from, r_into =
|
||||
if Term.is_value ra.n_term then rb, ra
|
||||
else if Term.is_value rb.n_term then ra, rb
|
||||
else if size_ ra > size_ rb then rb, ra
|
||||
else ra, rb
|
||||
in
|
||||
(* check we're not merging [true] and [false] *)
|
||||
if (N.equal ra (true_ cc) && N.equal rb (false_ cc)) ||
|
||||
(N.equal rb (true_ cc) && N.equal ra (false_ cc)) then (
|
||||
Log.debugf 5
|
||||
(fun k->k "(@[<hv>cc.merge.true_false_conflict@ @[:r1 %a@]@ @[:r2 %a@]@ :e_ab %a@])"
|
||||
N.pp ra N.pp rb Explanation.pp e_ab);
|
||||
let lits = explain_unfold cc e_ab in
|
||||
let lits = explain_eq_n ~init:lits cc a ra in
|
||||
let lits = explain_eq_n ~init:lits cc b rb in
|
||||
raise_conflict cc acts @@ Lit.Set.elements lits
|
||||
);
|
||||
(* TODO: isntead call micro theories, including "distinct" *)
|
||||
(* update set of tags the new node cannot be equal to *)
|
||||
let new_tags =
|
||||
Util.Int_map.union
|
||||
(fun _i (n1,e1) (n2,e2) ->
|
||||
(* both maps contain same tag [_i]. conflict clause:
|
||||
[e1 & e2 & e_ab] impossible *)
|
||||
Log.debugf 5
|
||||
(fun k->k "(@[<hv>cc.merge.distinct_conflict@ :tag %d@ \
|
||||
@[:r1 %a@ :e1 %a@]@ @[:r2 %a@ :e2 %a@]@ :e_ab %a@])"
|
||||
_i N.pp n1 Explanation.pp e1
|
||||
N.pp n2 Explanation.pp e2 Explanation.pp e_ab);
|
||||
let lits = explain_unfold cc e1 in
|
||||
let lits = explain_unfold ~init:lits cc e2 in
|
||||
let lits = explain_unfold ~init:lits cc e_ab in
|
||||
let lits = explain_eq_n ~init:lits cc a n1 in
|
||||
let lits = explain_eq_n ~init:lits cc b n2 in
|
||||
raise_conflict cc acts @@ Lit.Set.elements lits)
|
||||
ra.n_tags rb.n_tags
|
||||
in
|
||||
(* when merging terms with [true] or [false], possibly propagate them to SAT *)
|
||||
let merge_bool r1 t1 r2 t2 =
|
||||
if N.equal r1 (true_ cc) then (
|
||||
propagate_bools cc acts r2 t2 r1 t1 e_ab true
|
||||
) else if N.equal r1 (false_ cc) then (
|
||||
propagate_bools cc acts r2 t2 r1 t1 e_ab false
|
||||
)
|
||||
in
|
||||
merge_bool ra a rb b;
|
||||
merge_bool rb b ra a;
|
||||
(* perform [union r_from r_into] *)
|
||||
Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into);
|
||||
(* TODO: only iterate on parents of [rb] *)
|
||||
(* TODO: [ra.parents <- ra.parent ++ rb.parents] *)
|
||||
begin
|
||||
(* for each node in [r_from]'s class:
|
||||
- make it point to [r_into]
|
||||
- push it into [st.pending] *)
|
||||
iter_class_ r_from
|
||||
(fun u ->
|
||||
assert (u.n_root == r_from);
|
||||
on_backtrack cc (fun () -> u.n_root <- r_from);
|
||||
u.n_root <- r_into;
|
||||
Bag.to_seq u.n_parents
|
||||
(fun parent -> push_pending cc parent));
|
||||
(* now merge the classes *)
|
||||
let r_into_old_tags = r_into.n_tags in
|
||||
let r_into_old_next = r_into.n_next in
|
||||
let r_from_old_next = r_from.n_next in
|
||||
on_backtrack cc
|
||||
(fun () ->
|
||||
Log.debugf 15
|
||||
(fun k->k "(@[cc.undo_merge@ :from %a :into %a@])"
|
||||
Term.pp r_from.n_term Term.pp r_into.n_term);
|
||||
r_into.n_next <- r_into_old_next;
|
||||
r_from.n_next <- r_from_old_next;
|
||||
r_into.n_tags <- r_into_old_tags);
|
||||
r_into.n_tags <- new_tags;
|
||||
(* swap [into.next] and [from.next], merging the classes *)
|
||||
r_into.n_next <- r_from_old_next;
|
||||
r_from.n_next <- r_into_old_next;
|
||||
end;
|
||||
(* update explanations (a -> b), arbitrarily.
|
||||
Note that here we merge the classes by adding a bridge between [a]
|
||||
and [b], not their roots. *)
|
||||
begin
|
||||
reroot_expl cc a;
|
||||
assert (a.n_expl = E_none);
|
||||
(* on backtracking, link may be inverted, but we delete the one
|
||||
that bridges between [a] and [b] *)
|
||||
on_backtrack cc
|
||||
(fun () -> match a.n_expl, b.n_expl with
|
||||
| E_some e, _ when N.equal e.next b -> a.n_expl <- E_none
|
||||
| _, E_some e when N.equal e.next a -> b.n_expl <- E_none
|
||||
| _ -> assert false);
|
||||
a.n_expl <- E_some {next=b; expl=e_ab};
|
||||
end;
|
||||
(* notify listeners of the merge *)
|
||||
notify_merge cc r_from ~into:r_into e_ab;
|
||||
)
|
||||
|
||||
and task_distinct_ cc acts (l:node list) tag expl : unit =
|
||||
let l = List.map (fun n -> n, find cc n) l in
|
||||
let coll =
|
||||
Sequence.diagonal_l l
|
||||
|> Sequence.find_pred (fun ((_,r1),(_,r2)) -> N.equal r1 r2)
|
||||
in
|
||||
begin match coll with
|
||||
| Some ((n1,_r1),(n2,_r2)) ->
|
||||
(* two classes are already equal *)
|
||||
Log.debugf 5
|
||||
(fun k->k "(@[cc.distinct.conflict@ %a = %a@ :expl %a@])" N.pp n1 N.pp
|
||||
n2 Explanation.pp expl);
|
||||
let lits = explain_unfold cc expl in
|
||||
raise_conflict cc acts (Lit.Set.to_list lits)
|
||||
| None ->
|
||||
(* put a tag on all equivalence classes, that will make their merge fail *)
|
||||
List.iter (fun (_,n) -> add_tag_n cc n tag expl) l
|
||||
end
|
||||
|
||||
(* we are merging [r1] with [r2==Bool(sign)], so propagate each term [u1]
|
||||
in the equiv class of [r1] that is a known literal back to the SAT solver
|
||||
and which is not the one initially merged.
|
||||
We can explain the propagation with [u1 = t1 =e= t2 = r2==bool] *)
|
||||
and propagate_bools cc acts r1 t1 r2 t2 (e_12:explanation) sign : unit =
|
||||
(* explanation for [t1 =e= t2 = r2] *)
|
||||
let half_expl = lazy (
|
||||
let expl = explain_unfold cc e_12 in
|
||||
explain_eq_n ~init:expl cc r2 t2
|
||||
) in
|
||||
iter_class_ r1
|
||||
(fun u1 ->
|
||||
(* propagate if:
|
||||
- [u1] is a proper literal
|
||||
- [t2 != r2], because that can only happen
|
||||
after an explicit merge (no way to obtain that by propagation)
|
||||
*)
|
||||
if N.get_field N.field_is_literal u1 && not (N.equal r2 t2) then (
|
||||
let lit = Lit.atom ~sign u1.n_term in
|
||||
Log.debugf 5 (fun k->k "(@[cc.bool_propagate@ %a@])" Lit.pp lit);
|
||||
(* complete explanation with the [u1=t1] chunk *)
|
||||
let expl = explain_eq_n ~init:(Lazy.force half_expl) cc u1 t1 in
|
||||
let reason = Msat.Consequence (Lit.Set.to_list expl, Proof_default) in
|
||||
acts.Msat.acts_propagate lit reason
|
||||
))
|
||||
|
||||
(* add [t] to [cc] when not present already *)
|
||||
and add_new_term_ cc (t:term) : node =
|
||||
assert (not @@ mem cc t);
|
||||
Log.debugf 15 (fun k->k "(@[cc.add-term@ %a@])" Term.pp t);
|
||||
let n = N.make t in
|
||||
(* how to add a subterm *)
|
||||
let add_to_parents_of_sub_node (sub:node) : unit =
|
||||
let sub = find cc sub in (* update the repr! *)
|
||||
let old_parents = sub.n_parents in
|
||||
on_backtrack cc (fun () -> sub.n_parents <- old_parents);
|
||||
sub.n_parents <- Bag.cons n sub.n_parents;
|
||||
in
|
||||
(* add sub-term to [cc], and register [n] to its parents *)
|
||||
let add_sub_t (u:term) : unit =
|
||||
let n_u = add_term_rec_ cc u in
|
||||
add_to_parents_of_sub_node n_u
|
||||
in
|
||||
(* register sub-terms, add [t] to their parent list *)
|
||||
relevant_subterms t add_sub_t;
|
||||
(* remove term when we backtrack *)
|
||||
on_backtrack cc
|
||||
(fun () ->
|
||||
Log.debugf 15 (fun k->k "(@[cc.remove-term@ %a@])" Term.pp t);
|
||||
Term.Tbl.remove cc.tbl t);
|
||||
(* add term to the table *)
|
||||
Term.Tbl.add cc.tbl t n;
|
||||
(* [n] might be merged with other equiv classes *)
|
||||
push_pending cc n;
|
||||
n
|
||||
|
||||
(* add a term *)
|
||||
and[@inline] add_term_rec_ cc t : node =
|
||||
try Term.Tbl.find cc.tbl t
|
||||
with Not_found -> add_new_term_ cc t
|
||||
|
||||
let check_invariants_ (cc:t) =
|
||||
Log.debug 5 "(cc.check-invariants)";
|
||||
Log.debugf 15 (fun k-> k "%a" pp_full cc);
|
||||
assert (Term.equal (Term.true_ cc.tst) (true_ cc).n_term);
|
||||
assert (Term.equal (Term.false_ cc.tst) (false_ cc).n_term);
|
||||
assert (not @@ same_class cc (true_ cc) (false_ cc));
|
||||
assert (Vec.is_empty cc.combine);
|
||||
assert (Vec.is_empty cc.pending);
|
||||
(* check that subterms are internalized *)
|
||||
Term.Tbl.iter
|
||||
(fun t n ->
|
||||
assert (Term.equal t n.n_term);
|
||||
assert (not @@ N.get_field N.field_is_pending n);
|
||||
relevant_subterms t
|
||||
(fun u -> assert (Term.Tbl.mem cc.tbl u));
|
||||
assert (N.equal n.n_root n.n_next.n_root);
|
||||
(* check proper signature.
|
||||
note that some signatures in the sig table can be obsolete (they
|
||||
were not removed) but there must be a valid, up-to-date signature for
|
||||
each term *)
|
||||
begin match signature cc t with
|
||||
| None -> ()
|
||||
| Some s ->
|
||||
Log.debugf 15 (fun k->k "(@[cc.check-sig@ %a@ :sig %a@])" Term.pp t Signature.pp s);
|
||||
(* add, but only if not present already *)
|
||||
begin match Sig_tbl.find cc.signatures_tbl s with
|
||||
| exception Not_found -> assert false
|
||||
| repr_s -> assert (same_class cc n repr_s)
|
||||
end
|
||||
end;
|
||||
)
|
||||
cc.tbl;
|
||||
()
|
||||
|
||||
let[@inline] check_invariants (cc:t) : unit =
|
||||
if Util._CHECK_INVARIANTS then check_invariants_ cc
|
||||
|
||||
let[@inline] add cc t : node = add_term_rec_ cc t
|
||||
|
||||
let add_seq cc seq =
|
||||
seq (fun t -> ignore @@ add_term_rec_ cc t);
|
||||
()
|
||||
|
||||
let[@inline] push_level (self:t) : unit =
|
||||
Backtrack_stack.push_level self.undo
|
||||
|
||||
let pop_levels (self:t) n : unit =
|
||||
Vec.iter (N.set_field N.field_is_pending false) self.pending;
|
||||
Vec.clear self.pending;
|
||||
Vec.clear self.combine;
|
||||
Log.debugf 15
|
||||
(fun k->k "(@[cc.pop-levels %d@ :n-lvls %d@])" n (Backtrack_stack.n_levels self.undo));
|
||||
Backtrack_stack.pop_levels self.undo n ~f:(fun f -> f());
|
||||
()
|
||||
|
||||
(* TODO: if a lit is [= a b], merge [a] and [b];
|
||||
if it's [distinct a1…an], make them distinct, etc. etc. *)
|
||||
(* assert that this boolean literal holds *)
|
||||
let assert_lit cc lit : unit =
|
||||
let t = Lit.view lit in
|
||||
assert (Ty.is_prop t.term_ty);
|
||||
Log.debugf 5 (fun k->k "(@[cc.assert_lit@ %a@])" Lit.pp lit);
|
||||
let sign = Lit.sign lit in
|
||||
(* equate t and true/false *)
|
||||
let rhs = if sign then true_ cc else false_ cc in
|
||||
let n = add_term_rec_ cc t in
|
||||
(* TODO: ensure that this is O(1).
|
||||
basically, just have [n] point to true/false and thus acquire
|
||||
the corresponding value, so its superterms (like [ite]) can evaluate
|
||||
properly *)
|
||||
push_combine cc n rhs (Lazy.from_val @@ E_lit lit)
|
||||
|
||||
let[@inline] assert_lits cc lits : unit =
|
||||
Sequence.iter (assert_lit cc) lits
|
||||
|
||||
let assert_eq cc (t:term) (u:term) e : unit =
|
||||
let n1 = add_term_rec_ cc t in
|
||||
let n2 = add_term_rec_ cc u in
|
||||
if not (same_class cc n1 n2) then (
|
||||
let e = Lazy.from_val @@ Explanation.E_lits e in
|
||||
push_combine cc n1 n2 e;
|
||||
)
|
||||
|
||||
let assert_distinct cc (l:term list) ~neq (lit:Lit.t) : unit =
|
||||
assert (match l with[] | [_] -> false | _ -> true);
|
||||
let tag = Term.id neq in
|
||||
Log.debugf 5
|
||||
(fun k->k "(@[cc.assert_distinct@ (@[%a@])@ :tag %d@])" (Util.pp_list Term.pp) l tag);
|
||||
let l = List.map (add cc) l in
|
||||
Vec.push cc.combine @@ CT_distinct (l, tag, Explanation.lit lit)
|
||||
|
||||
let create ?on_merge ?(size=`Big) (tst:Term.state) : t =
|
||||
let size = match size with `Small -> 128 | `Big -> 2048 in
|
||||
let rec cc = {
|
||||
tst;
|
||||
tbl = Term.Tbl.create size;
|
||||
signatures_tbl = Sig_tbl.create size;
|
||||
on_merge;
|
||||
pending=Vec.create();
|
||||
combine=Vec.create();
|
||||
ps_lits=Lit.Set.empty;
|
||||
undo=Backtrack_stack.create();
|
||||
ps_queue=Vec.create();
|
||||
true_;
|
||||
false_;
|
||||
} and true_ = lazy (
|
||||
add_term_rec_ cc (Term.true_ tst)
|
||||
) and false_ = lazy (
|
||||
add_term_rec_ cc (Term.false_ tst)
|
||||
)
|
||||
in
|
||||
ignore (Lazy.force true_ : node);
|
||||
ignore (Lazy.force false_ : node);
|
||||
cc
|
||||
|
||||
let[@inline] find_t cc t : repr =
|
||||
let n = Term.Tbl.find cc.tbl t in
|
||||
find cc n
|
||||
|
||||
let[@inline] check cc acts : unit =
|
||||
Log.debug 5 "(cc.check)";
|
||||
update_tasks cc acts
|
||||
|
||||
(* model: map each uninterpreted equiv class to some ID *)
|
||||
let mk_model (cc:t) (m:Model.t) : Model.t =
|
||||
Log.debugf 15 (fun k->k "(@[cc.mk_model@ %a@])" pp_full cc);
|
||||
(* populate [repr -> value] table *)
|
||||
let t_tbl = N.Tbl.create 32 in
|
||||
(* type -> default value *)
|
||||
let ty_tbl = Ty.Tbl.create 8 in
|
||||
Term.Tbl.values cc.tbl
|
||||
(fun r ->
|
||||
if is_root_ r then (
|
||||
let t = r.n_term in
|
||||
let v = match Model.eval m t with
|
||||
| Some v -> v
|
||||
| None ->
|
||||
if same_class cc r (true_ cc) then Value.true_
|
||||
else if same_class cc r (false_ cc) then Value.false_
|
||||
else (
|
||||
Value.mk_elt
|
||||
(ID.makef "v_%d" @@ Term.id t)
|
||||
(Term.ty r.n_term)
|
||||
)
|
||||
in
|
||||
if not @@ Ty.Tbl.mem ty_tbl (Term.ty t) then (
|
||||
Ty.Tbl.add ty_tbl (Term.ty t) v; (* also give a value to this type *)
|
||||
);
|
||||
N.Tbl.add t_tbl r v
|
||||
));
|
||||
(* now map every uninterpreted term to its representative's value, and
|
||||
create function tables *)
|
||||
let m, funs =
|
||||
Term.Tbl.to_seq cc.tbl
|
||||
|> Sequence.fold
|
||||
(fun (m,funs) (t,r) ->
|
||||
let r = find cc r in (* get representative *)
|
||||
match Term.view t with
|
||||
| _ when Model.mem t m -> m, funs
|
||||
| App_cst (c, args) ->
|
||||
if Model.mem t m then m, funs
|
||||
else if Cst.is_undefined c && IArray.length args > 0 then (
|
||||
(* update signature of [c] *)
|
||||
let ty = Term.ty t in
|
||||
let v = N.Tbl.find t_tbl r in
|
||||
let args =
|
||||
args
|
||||
|> IArray.map (fun t -> N.Tbl.find t_tbl @@ find_tn cc t)
|
||||
|> IArray.to_list
|
||||
in
|
||||
let ty, l = Cst.Map.get_or c funs ~default:(ty,[]) in
|
||||
m, Cst.Map.add c (ty, (args,v)::l) funs
|
||||
) else (
|
||||
let v = N.Tbl.find t_tbl r in
|
||||
Model.add t v m, funs
|
||||
)
|
||||
| _ ->
|
||||
let v = N.Tbl.find t_tbl r in
|
||||
Model.add t v m, funs)
|
||||
(m,Cst.Map.empty)
|
||||
in
|
||||
(* get or make a default value for this type *)
|
||||
let rec get_ty_default (ty:Ty.t) : Value.t =
|
||||
match Ty.view ty with
|
||||
| Ty_prop -> Value.true_
|
||||
| Ty_atomic { def = Ty_uninterpreted _;_} ->
|
||||
(* domain element *)
|
||||
Ty.Tbl.get_or_add ty_tbl ~k:ty
|
||||
~f:(fun ty -> Value.mk_elt (ID.makef "ty_%d" @@ Ty.id ty) ty)
|
||||
| Ty_atomic { def = Ty_def d; args; _} ->
|
||||
(* ask the theory for a default value *)
|
||||
Ty.Tbl.get_or_add ty_tbl ~k:ty
|
||||
~f:(fun _ty ->
|
||||
let vals = List.map get_ty_default args in
|
||||
d.default_val vals)
|
||||
in
|
||||
let funs =
|
||||
Cst.Map.map
|
||||
(fun (ty,l) ->
|
||||
Model.Fun_interpretation.make ~default:(get_ty_default ty) l)
|
||||
funs
|
||||
in
|
||||
Model.add_funs funs m
|
||||
|
|
@ -1,73 +0,0 @@
|
|||
(** {2 Congruence Closure} *)
|
||||
|
||||
open Solver_types
|
||||
|
||||
type t
|
||||
(** Global state of the congruence closure *)
|
||||
|
||||
type node = Eq_class.t
|
||||
(** Node in the congruence closure *)
|
||||
|
||||
type repr = Eq_class.t
|
||||
(** Node that is currently a representative *)
|
||||
|
||||
type conflict = Theory.conflict
|
||||
|
||||
val create :
|
||||
?on_merge:(repr -> repr -> explanation -> unit) ->
|
||||
?size:[`Small | `Big] ->
|
||||
Term.state ->
|
||||
t
|
||||
(** Create a new congruence closure.
|
||||
@param acts the actions available to the congruence closure
|
||||
*)
|
||||
|
||||
val find : t -> node -> repr
|
||||
(** Current representative *)
|
||||
|
||||
val add : t -> term -> node
|
||||
(** Add the term to the congruence closure, if not present already.
|
||||
Will be backtracked. *)
|
||||
|
||||
val find_t : t -> term -> repr
|
||||
(** Current representative of the term.
|
||||
@raise Not_found if the term is not already {!add}-ed. *)
|
||||
|
||||
val add_seq : t -> term Sequence.t -> unit
|
||||
(** Add a sequence of terms to the congruence closure *)
|
||||
|
||||
val all_classes : t -> repr Sequence.t
|
||||
(** All current classes *)
|
||||
|
||||
val assert_lit : t -> Lit.t -> unit
|
||||
(** Given a literal, assume it in the congruence closure and propagate
|
||||
its consequences. Will be backtracked. *)
|
||||
|
||||
val assert_lits : t -> Lit.t Sequence.t -> unit
|
||||
|
||||
val assert_eq : t -> term -> term -> Lit.t list -> unit
|
||||
|
||||
val assert_distinct : t -> term list -> neq:term -> Lit.t -> unit
|
||||
(** [assert_distinct l ~expl:u e] asserts all elements of [l] are distinct
|
||||
with explanation [e]
|
||||
precond: [u = distinct l] *)
|
||||
|
||||
val check : t -> sat_actions -> unit
|
||||
(** Perform all pending operations done via {!assert_eq}, {!assert_lit}, etc.
|
||||
Will use the [sat_actions] to propagate literals, declare conflicts, etc. *)
|
||||
|
||||
val push_level : t -> unit
|
||||
|
||||
val pop_levels : t -> int -> unit
|
||||
|
||||
val mk_model : t -> Model.t -> Model.t
|
||||
(** Enrich a model by mapping terms to their representative's value,
|
||||
if any. Otherwise map the representative to a fresh value *)
|
||||
|
||||
(**/**)
|
||||
val check_invariants : t -> unit
|
||||
val pp_full : t Fmt.printer
|
||||
(**/**)
|
||||
|
||||
module T_arg : Mini_cc_intf.ARG with type Fun.t = cst and type Term.t = Term.t
|
||||
module Mini_cc : module type of Mini_cc.Make(T_arg)
|
||||
|
|
@ -1,66 +0,0 @@
|
|||
|
||||
open Solver_types
|
||||
|
||||
type t = equiv_class
|
||||
type payload = equiv_class_payload = ..
|
||||
|
||||
let field_is_active = Node_bits.mk_field()
|
||||
let field_is_pending = Node_bits.mk_field()
|
||||
let field_is_literal = Node_bits.mk_field()
|
||||
let () = Node_bits.freeze()
|
||||
|
||||
let[@inline] equal (n1:t) n2 = n1==n2
|
||||
let[@inline] hash n = Term.hash n.n_term
|
||||
let[@inline] term n = n.n_term
|
||||
let[@inline] payload n = n.n_payload
|
||||
let[@inline] pp out n = Term.pp out n.n_term
|
||||
|
||||
let make (t:term) : t =
|
||||
let rec n = {
|
||||
n_term=t;
|
||||
n_bits=Node_bits.empty;
|
||||
n_parents=Bag.empty;
|
||||
n_root=n;
|
||||
n_expl=E_none;
|
||||
n_payload=[];
|
||||
n_next=n;
|
||||
n_size=1;
|
||||
n_tags=Util.Int_map.empty;
|
||||
} in
|
||||
n
|
||||
|
||||
let set_payload ?(can_erase=fun _->false) n e =
|
||||
let rec aux = function
|
||||
| [] -> [e]
|
||||
| e' :: tail when can_erase e' -> e :: tail
|
||||
| e' :: tail -> e' :: aux tail
|
||||
in
|
||||
n.n_payload <- aux n.n_payload
|
||||
|
||||
let payload_find ~f:p n =
|
||||
let[@unroll 2] rec aux = function
|
||||
| [] -> None
|
||||
| e1 :: tail ->
|
||||
match p e1 with
|
||||
| Some _ as res -> res
|
||||
| None -> aux tail
|
||||
in
|
||||
aux n.n_payload
|
||||
|
||||
let payload_pred ~f:p n =
|
||||
begin match n.n_payload with
|
||||
| [] -> false
|
||||
| e :: _ when p e -> true
|
||||
| _ :: e :: _ when p e -> true
|
||||
| _ :: _ :: e :: _ when p e -> true
|
||||
| l -> List.exists p l
|
||||
end
|
||||
|
||||
let[@inline] get_field f t = Node_bits.get f t.n_bits
|
||||
let[@inline] set_field f b t = t.n_bits <- Node_bits.set f b t.n_bits
|
||||
|
||||
module Tbl = CCHashtbl.Make(struct
|
||||
type t = equiv_class
|
||||
let equal = equal
|
||||
let hash = hash
|
||||
end)
|
||||
|
|
@ -1,61 +0,0 @@
|
|||
|
||||
open Solver_types
|
||||
|
||||
(** {1 Equivalence Classes} *)
|
||||
|
||||
(** An equivalence class is a set of terms that are currently equal
|
||||
in the partial model built by the solver.
|
||||
The class is represented by a collection of nodes, one of which is
|
||||
distinguished and is called the "representative".
|
||||
|
||||
All information pertaining to the whole equivalence class is stored
|
||||
in this representative's node.
|
||||
|
||||
When two classes become equal (are "merged"), one of the two
|
||||
representatives is picked as the representative of the new class.
|
||||
The new class contains the union of the two old classes' nodes.
|
||||
|
||||
We also allow theories to store additional information in the
|
||||
representative. This information can be used when two classes are
|
||||
merged, to detect conflicts and solve equations à la Shostak.
|
||||
*)
|
||||
|
||||
type t = equiv_class
|
||||
type payload = equiv_class_payload = ..
|
||||
|
||||
val field_is_active : Node_bits.field
|
||||
(** The term is needed for evaluation. We must try to evaluate it
|
||||
or to find a value for it using the theory *)
|
||||
|
||||
val field_is_pending : Node_bits.field
|
||||
(** true iff the node is in the [cc.pending] queue *)
|
||||
|
||||
val field_is_literal : Node_bits.field
|
||||
(** This term is a boolean literal, subject to propagations *)
|
||||
|
||||
(** {2 basics} *)
|
||||
|
||||
val term : t -> term
|
||||
val equal : t -> t -> bool
|
||||
val hash : t -> int
|
||||
val pp : t Fmt.printer
|
||||
val payload : t -> payload list
|
||||
|
||||
(** {2 Helpers} *)
|
||||
|
||||
val make : term -> t
|
||||
(** Make a new equivalence class whose representative is the given term *)
|
||||
|
||||
val payload_find: f:(payload -> 'a option) -> t -> 'a option
|
||||
|
||||
val payload_pred: f:(payload -> bool) -> t -> bool
|
||||
|
||||
val set_payload : ?can_erase:(payload -> bool) -> t -> payload -> unit
|
||||
(** Add given payload
|
||||
@param can_erase if provided, checks whether an existing value
|
||||
is to be replaced instead of adding a new entry *)
|
||||
|
||||
val get_field : Node_bits.field -> t -> bool
|
||||
val set_field : Node_bits.field -> bool -> t -> unit
|
||||
|
||||
module Tbl : CCHashtbl.S with type key = t
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
|
||||
open Solver_types
|
||||
|
||||
type t = explanation =
|
||||
| E_reduction (* by pure reduction, tautologically equal *)
|
||||
| E_merges of (equiv_class * equiv_class) IArray.t (* caused by these merges *)
|
||||
| E_lit of lit (* because of this literal *)
|
||||
| E_lits of lit list (* because of this (true) conjunction *)
|
||||
|
||||
let compare = cmp_exp
|
||||
let equal a b = cmp_exp a b = 0
|
||||
|
||||
let pp = pp_explanation
|
||||
|
||||
let mk_merges l : t = E_merges l
|
||||
let mk_lit l : t = E_lit l
|
||||
let mk_lits = function [x] -> mk_lit x | l -> E_lits l
|
||||
let mk_reduction : t = E_reduction
|
||||
|
||||
let[@inline] lit l : t = E_lit l
|
||||
|
||||
module Set = CCSet.Make(struct
|
||||
type t = explanation
|
||||
let compare = compare
|
||||
end)
|
||||
|
||||
|
|
@ -8,7 +8,7 @@ type t = lit = {
|
|||
|
||||
let[@inline] neg l = {l with lit_sign=not l.lit_sign}
|
||||
let[@inline] sign t = t.lit_sign
|
||||
let[@inline] view (t:t): term = t.lit_term
|
||||
let[@inline] term (t:t): term = t.lit_term
|
||||
|
||||
let[@inline] abs t: t = {t with lit_sign=true}
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ type t = lit = {
|
|||
val neg : t -> t
|
||||
val abs : t -> t
|
||||
val sign : t -> bool
|
||||
val view : t -> term
|
||||
val term : t -> term
|
||||
val as_atom : t -> term * bool
|
||||
val atom : ?sign:bool -> term -> t
|
||||
val hash : t -> int
|
||||
|
|
|
|||
|
|
@ -1,18 +0,0 @@
|
|||
|
||||
(** {1 Mini congruence closure} *)
|
||||
|
||||
type ('f, 't, 'ts) view = ('f, 't, 'ts) Mini_cc_intf.view =
|
||||
| Bool of bool
|
||||
| App of 'f * 'ts
|
||||
| If of 't * 't * 't
|
||||
|
||||
type res = Mini_cc_intf.res =
|
||||
| Sat
|
||||
| Unsat
|
||||
|
||||
module type ARG = Mini_cc_intf.ARG
|
||||
module type S = Mini_cc_intf.S
|
||||
|
||||
module Make(A: ARG)
|
||||
: S with type term = A.Term.t
|
||||
and type fun_ = A.Fun.t
|
||||
|
|
@ -1,47 +0,0 @@
|
|||
|
||||
type ('f, 't, 'ts) view =
|
||||
| Bool of bool
|
||||
| App of 'f * 'ts
|
||||
| If of 't * 't * 't
|
||||
|
||||
(* TODO: also HO app, Eq, Distinct cases?
|
||||
-> then API that just adds boolean terms and does the right thing in case of
|
||||
Eq/Distinct *)
|
||||
|
||||
type res =
|
||||
| Sat
|
||||
| Unsat
|
||||
|
||||
module type ARG = sig
|
||||
module Fun : sig
|
||||
type t
|
||||
val equal : t -> t -> bool
|
||||
val hash : t -> int
|
||||
val pp : t Fmt.printer
|
||||
end
|
||||
|
||||
module Term : sig
|
||||
type t
|
||||
val equal : t -> t -> bool
|
||||
val hash : t -> int
|
||||
val pp : t Fmt.printer
|
||||
|
||||
(** View the term through the lens of the congruence closure *)
|
||||
val view : t -> (Fun.t, t, t Sequence.t) view
|
||||
end
|
||||
end
|
||||
|
||||
module type S = sig
|
||||
type term
|
||||
type fun_
|
||||
|
||||
type t
|
||||
|
||||
val create : unit -> t
|
||||
|
||||
val merge : t -> term -> term -> unit
|
||||
val distinct : t -> term list -> unit
|
||||
|
||||
val check : t -> res
|
||||
end
|
||||
|
||||
|
|
@ -58,6 +58,24 @@ let empty : t = {
|
|||
funs=Cst.Map.empty;
|
||||
}
|
||||
|
||||
(* FIXME: ues this to allocate a default value for each sort
|
||||
(* get or make a default value for this type *)
|
||||
let rec get_ty_default (ty:Ty.t) : Value.t =
|
||||
match Ty.view ty with
|
||||
| Ty_prop -> Value.true_
|
||||
| Ty_atomic { def = Ty_uninterpreted _;_} ->
|
||||
(* domain element *)
|
||||
Ty_tbl.get_or_add ty_tbl ~k:ty
|
||||
~f:(fun ty -> Value.mk_elt (ID.makef "ty_%d" @@ Ty.id ty) ty)
|
||||
| Ty_atomic { def = Ty_def d; args; _} ->
|
||||
(* ask the theory for a default value *)
|
||||
Ty_tbl.get_or_add ty_tbl ~k:ty
|
||||
~f:(fun _ty ->
|
||||
let vals = List.map get_ty_default args in
|
||||
d.default_val vals)
|
||||
in
|
||||
*)
|
||||
|
||||
let[@inline] mem t m = Term.Map.mem t m.values
|
||||
let[@inline] find t m = Term.Map.get t m.values
|
||||
|
||||
|
|
@ -102,9 +120,9 @@ let add_funs fs m : t = merge {values=Term.Map.empty; funs=fs} m
|
|||
|
||||
let pp out {values; funs} =
|
||||
let module FI = Fun_interpretation in
|
||||
let pp_tv out (t,v) = Fmt.fprintf out "(@[%a@ %a@])" Term.pp t Value.pp v in
|
||||
let pp_tv out (t,v) = Fmt.fprintf out "(@[%a@ := %a@])" Term.pp t Value.pp v in
|
||||
let pp_fun_entry out (vals,ret) =
|
||||
Format.fprintf out "(@[%a@ %a@])" (Fmt.Dump.list Value.pp) vals Value.pp ret
|
||||
Format.fprintf out "(@[%a@ := %a@])" (Fmt.Dump.list Value.pp) vals Value.pp ret
|
||||
in
|
||||
let pp_fun out (c, fi: Cst.t * FI.t) =
|
||||
Format.fprintf out "(@[<hov>%a :default %a@ %a@])"
|
||||
|
|
@ -127,6 +145,10 @@ let eval (m:t) (t:Term.t) : Value.t option =
|
|||
| V_bool false -> aux c
|
||||
| v -> Error.errorf "@[Model: wrong value@ for boolean %a@ %a@]" Term.pp a Value.pp v
|
||||
end
|
||||
| Eq(a,b) ->
|
||||
let a = aux a in
|
||||
let b = aux b in
|
||||
if Value.equal a b then Value.true_ else Value.false_
|
||||
| App_cst (c, args) ->
|
||||
begin try Term.Map.find t m.values
|
||||
with Not_found ->
|
||||
|
|
|
|||
|
|
@ -37,10 +37,6 @@ val empty : t
|
|||
|
||||
val add : Term.t -> Value.t -> t -> t
|
||||
|
||||
val add_fun : Cst.t -> Fun_interpretation.t -> t -> t
|
||||
|
||||
val add_funs : Fun_interpretation.t Cst.Map.t -> t -> t
|
||||
|
||||
val mem : Term.t -> t -> bool
|
||||
|
||||
val find : Term.t -> t -> Value.t option
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@ module Solver = Solver
|
|||
module Solver_types = Solver_types
|
||||
|
||||
(**/**)
|
||||
module Bag = Bag
|
||||
module Vec = Msat.Vec
|
||||
module Log = Msat.Log
|
||||
(**/**)
|
||||
|
|
|
|||
|
|
@ -208,11 +208,9 @@ let assume (self:t) (c:Lit.t IArray.t) : unit =
|
|||
let c = IArray.to_array_map (Sat_solver.make_atom sat) c in
|
||||
Sat_solver.add_clause_a sat c Proof_default
|
||||
|
||||
let[@inline] assume_eq self t u expl : unit =
|
||||
Congruence_closure.assert_eq (cc self) t u [expl]
|
||||
|
||||
(* TODO: remove? use a special constant + micro theory instead? *)
|
||||
let[@inline] assume_distinct self l ~neq lit : unit =
|
||||
Congruence_closure.assert_distinct (cc self) l lit ~neq
|
||||
CC.assert_distinct (cc self) l lit ~neq
|
||||
|
||||
let check_model (_s:t) : unit =
|
||||
Log.debug 1 "(smt.solver.check-model)";
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ val create :
|
|||
val solver : t -> Sat_solver.t
|
||||
val th_combine : t -> Theory_combine.t
|
||||
val add_theory : t -> Theory.t -> unit
|
||||
val cc : t -> Congruence_closure.t
|
||||
val cc : t -> CC.t
|
||||
val stats : t -> Stat.t
|
||||
val tst : t -> Term.state
|
||||
|
||||
|
|
@ -56,7 +56,6 @@ val mk_atom_t : t -> ?sign:bool -> Term.t -> Atom.t
|
|||
|
||||
val assume : t -> Lit.t IArray.t -> unit
|
||||
|
||||
val assume_eq : t -> Term.t -> Term.t -> Lit.t -> unit
|
||||
val assume_distinct : t -> Term.t list -> neq:Term.t -> Lit.t -> unit
|
||||
|
||||
val solve :
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ module Vec = Msat.Vec
|
|||
module Log = Msat.Log
|
||||
|
||||
module Fmt = CCFormat
|
||||
module Node_bits = CCBitField.Make(struct end)
|
||||
|
||||
(* for objects that are expanded on demand only *)
|
||||
type 'a lazily_expanded =
|
||||
|
|
@ -21,43 +20,9 @@ type term = {
|
|||
and 'a term_view =
|
||||
| Bool of bool
|
||||
| App_cst of cst * 'a IArray.t (* full, first-order application *)
|
||||
| Eq of 'a * 'a
|
||||
| If of 'a * 'a * 'a
|
||||
|
||||
(** A node of the congruence closure.
|
||||
An equivalence class is represented by its "root" element,
|
||||
the representative.
|
||||
|
||||
If there is a normal form in the congruence class, then the
|
||||
representative is a normal form *)
|
||||
and equiv_class = {
|
||||
n_term: term;
|
||||
mutable n_bits: Node_bits.t; (* bitfield for various properties *)
|
||||
mutable n_parents: equiv_class Bag.t; (* parent terms of this node *)
|
||||
mutable n_root: equiv_class; (* representative of congruence class (itself if a representative) *)
|
||||
mutable n_next: equiv_class; (* pointer to next element of congruence class *)
|
||||
mutable n_size: int; (* size of the class *)
|
||||
mutable n_expl: explanation_forest_link; (* the rooted forest for explanations *)
|
||||
mutable n_payload: equiv_class_payload list; (* list of theory payloads *)
|
||||
mutable n_tags: (equiv_class * explanation) Util.Int_map.t; (* "distinct" tags (i.e. set of `(distinct t1…tn)` terms this belongs to *)
|
||||
}
|
||||
|
||||
(** Theory-extensible payloads *)
|
||||
and equiv_class_payload = ..
|
||||
|
||||
and explanation_forest_link =
|
||||
| E_none
|
||||
| E_some of {
|
||||
next: equiv_class;
|
||||
expl: explanation;
|
||||
}
|
||||
|
||||
(* atomic explanation in the congruence closure *)
|
||||
and explanation =
|
||||
| E_reduction (* by pure reduction, tautologically equal *)
|
||||
| E_merges of (equiv_class * equiv_class) IArray.t (* caused by these merges *)
|
||||
| E_lit of lit (* because of this literal *)
|
||||
| E_lits of lit list (* because of this (true) conjunction *)
|
||||
|
||||
(* boolean literal *)
|
||||
and lit = {
|
||||
lit_term: term;
|
||||
|
|
@ -157,23 +122,6 @@ let hash_lit a =
|
|||
let sign = a.lit_sign in
|
||||
Hash.combine3 2 (Hash.bool sign) (term_hash_ a.lit_term)
|
||||
|
||||
let cmp_cc_node a b = term_cmp_ a.n_term b.n_term
|
||||
|
||||
let cmp_exp a b =
|
||||
let toint = function
|
||||
| E_merges _ -> 0 | E_lit _ -> 1
|
||||
| E_reduction -> 2 | E_lits _ -> 3
|
||||
in
|
||||
begin match a, b with
|
||||
| E_merges l1, E_merges l2 ->
|
||||
IArray.compare (CCOrd.pair cmp_cc_node cmp_cc_node) l1 l2
|
||||
| E_reduction, E_reduction -> 0
|
||||
| E_lit l1, E_lit l2 -> cmp_lit l1 l2
|
||||
| E_lits l1, E_lits l2 -> CCList.compare cmp_lit l1 l2
|
||||
| E_merges _, _ | E_lit _, _ | E_lits _, _ | E_reduction, _
|
||||
-> CCInt.compare (toint a)(toint b)
|
||||
end
|
||||
|
||||
let pp_cst out a = ID.pp out a.cst_id
|
||||
let id_of_cst a = a.cst_id
|
||||
|
||||
|
|
@ -215,6 +163,7 @@ let pp_term_view_gen ~pp_id ~pp_t out = function
|
|||
pp_id out (id_of_cst c)
|
||||
| App_cst (f,l) ->
|
||||
Fmt.fprintf out "(@[<1>%a@ %a@])" pp_id (id_of_cst f) (Util.pp_iarray pp_t) l
|
||||
| Eq (a,b) -> Fmt.fprintf out "(@[<hv>=@ %a@ %a@])" pp_t a pp_t b
|
||||
| If (a, b, c) ->
|
||||
Fmt.fprintf out "(@[if %a@ %a@ %a@])" pp_t a pp_t b pp_t c
|
||||
|
||||
|
|
@ -233,14 +182,5 @@ let pp_lit out l =
|
|||
if l.lit_sign then pp_term out l.lit_term
|
||||
else Format.fprintf out "(@[@<1>¬@ %a@])" pp_term l.lit_term
|
||||
|
||||
let pp_cc_node out n = pp_term out n.n_term
|
||||
|
||||
let pp_explanation out (e:explanation) = match e with
|
||||
| E_reduction -> Fmt.string out "reduction"
|
||||
| E_lit lit -> pp_lit out lit
|
||||
| E_lits l -> CCFormat.Dump.list pp_lit out l
|
||||
| E_merges l ->
|
||||
Format.fprintf out "(@[<hv1>merges@ %a@])"
|
||||
Fmt.(seq ~sep:(return "@ ") @@ within "[" "]" @@ hvbox @@
|
||||
pair ~sep:(return "@ <-> ") pp_cc_node pp_cc_node)
|
||||
(IArray.to_seq l)
|
||||
let pp_proof out = function
|
||||
| Proof_default -> Fmt.fprintf out "<default proof>"
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ type t = term = {
|
|||
type 'a view = 'a term_view =
|
||||
| Bool of bool
|
||||
| App_cst of cst * 'a IArray.t
|
||||
| Eq of 'a * 'a
|
||||
| If of 'a * 'a * 'a
|
||||
|
||||
let[@inline] id t = t.term_id
|
||||
|
|
@ -47,6 +48,7 @@ let[@inline] make st (c:t term_view) : t =
|
|||
|
||||
let[@inline] true_ st = Lazy.force st.true_
|
||||
let[@inline] false_ st = Lazy.force st.false_
|
||||
let bool st b = if b then true_ st else false_ st
|
||||
|
||||
let create ?(size=1024) () : state =
|
||||
let rec st ={
|
||||
|
|
@ -66,9 +68,9 @@ let app_cst st f a =
|
|||
let cell = Term_cell.app_cst f a in
|
||||
make st cell
|
||||
|
||||
let const st c = app_cst st c IArray.empty
|
||||
|
||||
let if_ st a b c = make st (Term_cell.if_ a b c)
|
||||
let[@inline] const st c = app_cst st c IArray.empty
|
||||
let[@inline] if_ st a b c = make st (Term_cell.if_ a b c)
|
||||
let[@inline] eq st a b = make st (Term_cell.eq a b)
|
||||
|
||||
(* "eager" and, evaluating [a] first *)
|
||||
let and_eager st a b = if_ st a b (false_ st)
|
||||
|
|
@ -87,10 +89,12 @@ let[@inline] is_const t = match view t with
|
|||
| _ -> false
|
||||
|
||||
let cc_view (t:t) =
|
||||
let module C = Mini_cc in
|
||||
let module C = Sidekick_cc in
|
||||
match view t with
|
||||
| Bool b -> C.Bool b
|
||||
| App_cst (f,args) -> C.App (f, IArray.to_seq args)
|
||||
| App_cst (f,_) when not (Cst.do_cc f) -> C.Opaque t (* skip *)
|
||||
| App_cst (f,args) -> C.App_fun (f, IArray.to_seq args)
|
||||
| Eq (a,b) -> C.Eq (a, b)
|
||||
| If (a,b,c) -> C.If (a,b,c)
|
||||
|
||||
module As_key = struct
|
||||
|
|
@ -109,6 +113,7 @@ let to_seq t yield =
|
|||
match view t with
|
||||
| Bool _ -> ()
|
||||
| App_cst (_,a) -> IArray.iter aux a
|
||||
| Eq (a,b) -> aux a; aux b
|
||||
| If (a,b,c) -> aux a; aux b; aux c
|
||||
in
|
||||
aux t
|
||||
|
|
@ -121,3 +126,14 @@ let as_cst_undef (t:term): (cst * Ty.Fun.t) option =
|
|||
|
||||
let pp = Solver_types.pp_term
|
||||
|
||||
|
||||
(* TODO
|
||||
module T_arg = struct
|
||||
module Fun = Cst
|
||||
module Term = struct
|
||||
include Term
|
||||
let view = cc_view
|
||||
end
|
||||
end
|
||||
module Mini_cc = Mini_cc.Make(T_arg)
|
||||
*)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ type t = term = {
|
|||
type 'a view = 'a term_view =
|
||||
| Bool of bool
|
||||
| App_cst of cst * 'a IArray.t
|
||||
| Eq of 'a * 'a
|
||||
| If of 'a * 'a * 'a
|
||||
|
||||
val id : t -> int
|
||||
|
|
@ -26,8 +27,10 @@ val create : ?size:int -> unit -> state
|
|||
val make : state -> t view -> t
|
||||
val true_ : state -> t
|
||||
val false_ : state -> t
|
||||
val bool : state -> bool -> t
|
||||
val const : state -> cst -> t
|
||||
val app_cst : state -> cst -> t IArray.t -> t
|
||||
val eq : state -> t -> t -> t
|
||||
val if_: state -> t -> t -> t -> t
|
||||
val and_eager : state -> t -> t -> t (* evaluate left argument first *)
|
||||
|
||||
|
|
@ -49,7 +52,7 @@ val is_true : t -> bool
|
|||
val is_false : t -> bool
|
||||
val is_const : t -> bool
|
||||
|
||||
val cc_view : t -> (cst,t,t Sequence.t) Mini_cc.view
|
||||
val cc_view : t -> (cst,t,t Sequence.t) Sidekick_cc.view
|
||||
|
||||
(* return [Some] iff the term is an undefined constant *)
|
||||
val as_cst_undef : t -> (cst * Ty.Fun.t) option
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ open Solver_types
|
|||
type 'a view = 'a Solver_types.term_view =
|
||||
| Bool of bool
|
||||
| App_cst of cst * 'a IArray.t
|
||||
| Eq of 'a * 'a
|
||||
| If of 'a * 'a * 'a
|
||||
|
||||
type t = term view
|
||||
|
|
@ -25,6 +26,7 @@ module Make_eq(A : ARG) = struct
|
|||
| Bool b -> Hash.bool b
|
||||
| App_cst (f,l) ->
|
||||
Hash.combine3 4 (Cst.hash f) (Hash.iarray sub_hash l)
|
||||
| Eq (a,b) -> Hash.combine3 12 (sub_hash a) (sub_hash b)
|
||||
| If (a,b,c) -> Hash.combine4 7 (sub_hash a) (sub_hash b) (sub_hash c)
|
||||
|
||||
(* equality that relies on physical equality of subterms *)
|
||||
|
|
@ -32,9 +34,10 @@ module Make_eq(A : ARG) = struct
|
|||
| Bool b1, Bool b2 -> CCBool.equal b1 b2
|
||||
| App_cst (f1, a1), App_cst (f2, a2) ->
|
||||
Cst.equal f1 f2 && IArray.equal sub_eq a1 a2
|
||||
| Eq(a1,b1), Eq(a2,b2) -> sub_eq a1 a2 && sub_eq b1 b2
|
||||
| If (a1,b1,c1), If (a2,b2,c2) ->
|
||||
sub_eq a1 a2 && sub_eq b1 b2 && sub_eq c1 c2
|
||||
| Bool _, _ | App_cst _, _ | If _, _
|
||||
| Bool _, _ | App_cst _, _ | If _, _ | Eq _, _
|
||||
-> false
|
||||
|
||||
let pp = Solver_types.pp_term_view_gen ~pp_id:ID.pp_name ~pp_t:A.pp
|
||||
|
|
@ -53,17 +56,25 @@ let false_ = Bool false
|
|||
let is_value = function
|
||||
| Bool _ -> true
|
||||
| App_cst ({cst_view=Cst_def r;_}, _) -> r.is_value
|
||||
| If _ | App_cst _ -> false
|
||||
| If _ | App_cst _ | Eq _ -> false
|
||||
|
||||
let app_cst f a = App_cst (f, a)
|
||||
let const c = App_cst (c, IArray.empty)
|
||||
let eq a b =
|
||||
if term_equal_ a b then (
|
||||
Bool true
|
||||
) else (
|
||||
(* canonize *)
|
||||
let a,b = if a.term_id > b.term_id then b, a else a, b in
|
||||
Eq (a,b)
|
||||
)
|
||||
|
||||
let if_ a b c =
|
||||
assert (Ty.equal b.term_ty c.term_ty);
|
||||
If (a,b,c)
|
||||
|
||||
let ty (t:t): Ty.t = match t with
|
||||
| Bool _ -> Ty.prop
|
||||
| Bool _ | Eq _ -> Ty.prop
|
||||
| App_cst (f, args) ->
|
||||
begin match Cst.view f with
|
||||
| Cst_undef fty ->
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ open Solver_types
|
|||
type 'a view = 'a Solver_types.term_view =
|
||||
| Bool of bool
|
||||
| App_cst of cst * 'a IArray.t
|
||||
| Eq of 'a * 'a
|
||||
| If of 'a * 'a * 'a
|
||||
|
||||
type t = term view
|
||||
|
|
@ -15,6 +16,7 @@ val true_ : t
|
|||
val false_ : t
|
||||
val const : cst -> t
|
||||
val app_cst : cst -> term IArray.t -> t
|
||||
val eq : term -> term -> t
|
||||
val if_ : term -> term -> term -> t
|
||||
|
||||
val is_value : t -> bool
|
||||
|
|
|
|||
|
|
@ -18,6 +18,9 @@ end
|
|||
Its negation will become a conflict clause *)
|
||||
type conflict = Lit.t list
|
||||
|
||||
module CC_eq_class = CC.N
|
||||
module CC_expl = CC.Expl
|
||||
|
||||
(** Actions available to a theory during its lifetime *)
|
||||
module type ACTIONS = sig
|
||||
val raise_conflict: conflict -> 'a
|
||||
|
|
@ -41,12 +44,15 @@ module type ACTIONS = sig
|
|||
(** Add toplevel clause to the SAT solver. This clause will
|
||||
not be backtracked. *)
|
||||
|
||||
val find: Term.t -> Eq_class.t
|
||||
(** Find representative of this term *)
|
||||
val cc_add_term: Term.t -> CC_eq_class.t
|
||||
(** add/get term to the congruence closure *)
|
||||
|
||||
val all_classes: Eq_class.t Sequence.t
|
||||
val cc_find: CC_eq_class.t -> CC_eq_class.t
|
||||
(** Find representative of this in the congruence closure *)
|
||||
|
||||
val cc_all_classes: CC_eq_class.t Sequence.t
|
||||
(** All current equivalence classes
|
||||
(caution: linear in the number of terms existing in the solver) *)
|
||||
(caution: linear in the number of terms existing in the congruence closure) *)
|
||||
end
|
||||
|
||||
type actions = (module ACTIONS)
|
||||
|
|
@ -60,7 +66,7 @@ module type S = sig
|
|||
val create : Term.state -> t
|
||||
(** Instantiate the theory's state *)
|
||||
|
||||
val on_merge: t -> actions -> Eq_class.t -> Eq_class.t -> Explanation.t -> unit
|
||||
val on_merge: t -> actions -> CC_eq_class.t -> CC_eq_class.t -> CC_expl.t -> unit
|
||||
(** Called when two classes are merged *)
|
||||
|
||||
val partial_check : t -> actions -> Lit.t Sequence.t -> unit
|
||||
|
|
@ -70,7 +76,7 @@ module type S = sig
|
|||
(** Final check, must be complete (i.e. must raise a conflict
|
||||
if the set of literals is not satisfiable) *)
|
||||
|
||||
val mk_model : t -> Lit.t Sequence.t -> Model.t
|
||||
val mk_model : t -> Lit.t Sequence.t -> Model.t -> Model.t
|
||||
(** Make a model for this theory's terms *)
|
||||
|
||||
val push_level : t -> unit
|
||||
|
|
@ -91,7 +97,7 @@ let make
|
|||
?(check_invariants=fun _ -> ())
|
||||
?(on_merge=fun _ _ _ _ _ -> ())
|
||||
?(partial_check=fun _ _ _ -> ())
|
||||
?(mk_model=fun _ _ -> Model.empty)
|
||||
?(mk_model=fun _ _ m -> m)
|
||||
?(push_level=fun _ -> ())
|
||||
?(pop_levels=fun _ _ -> ())
|
||||
~name
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
|
||||
(** Combine the congruence closure with a number of plugins *)
|
||||
|
||||
module C_clos = Congruence_closure
|
||||
open Solver_types
|
||||
|
||||
module Proof = struct
|
||||
|
|
@ -12,6 +11,8 @@ module Proof = struct
|
|||
end
|
||||
|
||||
module Formula = Lit
|
||||
module Eq_class = CC.N
|
||||
module Expl = CC.Expl
|
||||
|
||||
type formula = Lit.t
|
||||
type proof = Proof.t
|
||||
|
|
@ -24,11 +25,11 @@ type theory_state =
|
|||
type t = {
|
||||
tst: Term.state;
|
||||
(** state for managing terms *)
|
||||
cc: C_clos.t lazy_t;
|
||||
cc: CC.t lazy_t;
|
||||
(** congruence closure *)
|
||||
mutable theories : theory_state list;
|
||||
(** Set of theories *)
|
||||
new_merges: (Eq_class.t * Eq_class.t * explanation) Vec.t;
|
||||
new_merges: (Eq_class.t * Eq_class.t * Expl.t) Vec.t;
|
||||
}
|
||||
|
||||
let[@inline] cc (t:t) = Lazy.force t.cc
|
||||
|
|
@ -41,24 +42,28 @@ let[@inline] theories (self:t) : theory_state Sequence.t =
|
|||
(* handle a literal assumed by the SAT solver *)
|
||||
let assert_lits_ ~final (self:t) acts (lits:Lit.t Sequence.t) : unit =
|
||||
Msat.Log.debugf 2
|
||||
(fun k->k "(@[<1>@{<green>th_combine.assume_lits@}@ @[%a@]@])" (Fmt.seq Lit.pp) lits);
|
||||
(fun k->k "(@[<hv1>@{<green>th_combine.assume_lits@}@ %a@])"
|
||||
(Util.pp_seq ~sep:";" Lit.pp) lits);
|
||||
(* transmit to CC *)
|
||||
Vec.clear self.new_merges;
|
||||
let cc = cc self in
|
||||
C_clos.assert_lits cc lits;
|
||||
if not final then (
|
||||
CC.assert_lits cc lits;
|
||||
);
|
||||
(* transmit to theories. *)
|
||||
C_clos.check cc acts;
|
||||
CC.check cc acts;
|
||||
let module A = struct
|
||||
let[@inline] raise_conflict c : 'a = acts.Msat.acts_raise_conflict c Proof_default
|
||||
let[@inline] propagate_eq t u expl : unit = C_clos.assert_eq cc t u expl
|
||||
let propagate_distinct ts ~neq expl = C_clos.assert_distinct cc ts ~neq expl
|
||||
let[@inline] propagate_eq t u expl : unit = CC.assert_eq cc t u expl
|
||||
let propagate_distinct ts ~neq expl = CC.assert_distinct cc ts ~neq expl
|
||||
let[@inline] propagate p cs : unit = acts.Msat.acts_propagate p (Msat.Consequence (cs, Proof_default))
|
||||
let[@inline] add_local_axiom lits : unit =
|
||||
acts.Msat.acts_add_clause ~keep:false lits Proof_default
|
||||
let[@inline] add_persistent_axiom lits : unit =
|
||||
acts.Msat.acts_add_clause ~keep:true lits Proof_default
|
||||
let[@inline] find t = C_clos.find_t cc t
|
||||
let all_classes = C_clos.all_classes cc
|
||||
let[@inline] cc_add_term t = CC.add_term cc t
|
||||
let[@inline] cc_find t = CC.find cc t
|
||||
let cc_all_classes = CC.all_classes cc
|
||||
end in
|
||||
let acts = (module A : Theory.ACTIONS) in
|
||||
theories self
|
||||
|
|
@ -83,10 +88,10 @@ let check_ ~final (self:t) (acts:_ Msat.acts) =
|
|||
assert_lits_ ~final self acts iter
|
||||
|
||||
let add_formula (self:t) (lit:Lit.t) =
|
||||
let t = Lit.view lit in
|
||||
let t = Lit.term lit in
|
||||
let lazy cc = self.cc in
|
||||
let n = C_clos.add cc t in
|
||||
Eq_class.set_field Eq_class.field_is_literal true n;
|
||||
let n = CC.add_term cc t in
|
||||
CC.set_as_lit cc n (Lit.abs lit);
|
||||
()
|
||||
|
||||
(* propagation from the bool solver *)
|
||||
|
|
@ -98,21 +103,21 @@ let[@inline] final_check (self:t) (acts:_ Msat.acts) : unit =
|
|||
check_ ~final:true self acts
|
||||
|
||||
let push_level (self:t) : unit =
|
||||
C_clos.push_level (cc self);
|
||||
CC.push_level (cc self);
|
||||
theories self (fun (Th_state ((module Th), st)) -> Th.push_level st)
|
||||
|
||||
let pop_levels (self:t) n : unit =
|
||||
C_clos.pop_levels (cc self) n;
|
||||
CC.pop_levels (cc self) n;
|
||||
theories self (fun (Th_state ((module Th), st)) -> Th.pop_levels st n)
|
||||
|
||||
let mk_model (self:t) lits : Model.t =
|
||||
let m =
|
||||
Sequence.fold
|
||||
(fun m (Th_state ((module Th),st)) -> Model.merge m (Th.mk_model st lits))
|
||||
(fun m (Th_state ((module Th),st)) -> Th.mk_model st lits m)
|
||||
Model.empty (theories self)
|
||||
in
|
||||
(* now complete model using CC *)
|
||||
Congruence_closure.mk_model (cc self) m
|
||||
CC.mk_model (cc self) m
|
||||
|
||||
(** {2 Interface to Congruence Closure} *)
|
||||
|
||||
|
|
@ -131,16 +136,16 @@ let create () : t =
|
|||
cc = lazy (
|
||||
(* lazily tie the knot *)
|
||||
let on_merge = on_merge_from_cc self in
|
||||
C_clos.create ~on_merge ~size:`Big self.tst;
|
||||
CC.create ~on_merge ~size:`Big self.tst;
|
||||
);
|
||||
theories = [];
|
||||
} in
|
||||
ignore (Lazy.force @@ self.cc : C_clos.t);
|
||||
ignore (Lazy.force @@ self.cc : CC.t);
|
||||
self
|
||||
|
||||
let check_invariants (self:t) =
|
||||
if Util._CHECK_INVARIANTS then (
|
||||
Congruence_closure.check_invariants (cc self);
|
||||
CC.check_invariants (cc self);
|
||||
)
|
||||
|
||||
let add_theory (self:t) (th:Theory.t) : unit =
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ include Msat.Solver_intf.PLUGIN_CDCL_T
|
|||
|
||||
val create : unit -> t
|
||||
|
||||
val cc : t -> Congruence_closure.t
|
||||
val cc : t -> CC.t
|
||||
val tst : t -> Term.state
|
||||
|
||||
type theory_state =
|
||||
|
|
|
|||
|
|
@ -19,3 +19,5 @@ let equal = eq_value
|
|||
let hash = hash_value
|
||||
let pp = pp_value
|
||||
|
||||
let fresh (t:term) : t =
|
||||
mk_elt (ID.makef "v_%d" t.term_id) t.term_ty
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@ val is_bool : t -> bool
|
|||
val is_true : t -> bool
|
||||
val is_false : t -> bool
|
||||
|
||||
val fresh : Term.t -> t
|
||||
|
||||
include Intf.EQ with type t := t
|
||||
include Intf.HASH with type t := t
|
||||
include Intf.PRINT with type t := t
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@
|
|||
(library
|
||||
(name Sidekick_smt)
|
||||
(public_name sidekick.smt)
|
||||
(libraries containers containers.data sequence sidekick.util msat zarith)
|
||||
(libraries containers containers.data sequence
|
||||
sidekick.util sidekick.cc msat zarith)
|
||||
(flags :standard -warn-error -a+8
|
||||
-color always -safe-string -short-paths -open Sidekick_util)
|
||||
(ocamlopt_flags :standard -O3 -color always
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ let id_not = ID.make "not"
|
|||
let id_and = ID.make "and"
|
||||
let id_or = ID.make "or"
|
||||
let id_imply = ID.make "=>"
|
||||
let id_eq = ID.make "="
|
||||
let id_distinct = ID.make "distinct"
|
||||
|
||||
type 'a view =
|
||||
|
|
@ -32,8 +31,6 @@ exception Not_a_th_term
|
|||
let view_id cst_id args =
|
||||
if ID.equal cst_id id_not && IArray.length args=1 then (
|
||||
B_not (IArray.get args 0)
|
||||
) else if ID.equal cst_id id_eq && IArray.length args=2 then (
|
||||
B_eq (IArray.get args 0, IArray.get args 1)
|
||||
) else if ID.equal cst_id id_and then (
|
||||
B_and args
|
||||
) else if ID.equal cst_id id_or then (
|
||||
|
|
@ -45,13 +42,14 @@ let view_id cst_id args =
|
|||
) else if ID.equal cst_id id_distinct then (
|
||||
B_distinct args
|
||||
) else (
|
||||
raise Not_a_th_term
|
||||
raise_notrace Not_a_th_term
|
||||
)
|
||||
|
||||
let view (t:Term.t) : term view =
|
||||
match Term.view t with
|
||||
| Eq (a,b) -> B_eq (a,b)
|
||||
| App_cst ({cst_id; _}, args) ->
|
||||
(try view_id cst_id args with Not_a_th_term -> B_atom t)
|
||||
begin try view_id cst_id args with Not_a_th_term -> B_atom t end
|
||||
| _ -> B_atom t
|
||||
|
||||
|
||||
|
|
@ -59,9 +57,6 @@ module C = struct
|
|||
|
||||
let get_ty _ _ = Ty.prop
|
||||
|
||||
(* no congruence closure, except for `=` *)
|
||||
let relevant id _ _ = ID.equal id_eq id
|
||||
|
||||
let abs ~self _a =
|
||||
match Term.view self with
|
||||
| App_cst ({cst_id;_}, args) when ID.equal cst_id id_not && IArray.length args=1 ->
|
||||
|
|
@ -89,6 +84,9 @@ module C = struct
|
|||
| B_not _ | B_and _ | B_or _ | B_imply _
|
||||
-> Error.errorf "non boolean value in boolean connective"
|
||||
|
||||
(* no congruence closure for boolean terms *)
|
||||
let relevant _id _ _ = false
|
||||
|
||||
let mk_cst ?(do_cc=false) id : Cst.t =
|
||||
{cst_id=id;
|
||||
cst_view=Cst_def {
|
||||
|
|
@ -98,7 +96,6 @@ module C = struct
|
|||
let and_ = mk_cst id_and
|
||||
let or_ = mk_cst id_or
|
||||
let imply = mk_cst id_imply
|
||||
let eq = mk_cst ~do_cc:true id_eq
|
||||
let distinct = mk_cst id_distinct
|
||||
end
|
||||
|
||||
|
|
@ -134,13 +131,7 @@ let or_l st l =
|
|||
let and_ st a b = and_l st [a;b]
|
||||
let or_ st a b = or_l st [a;b]
|
||||
|
||||
let eq st a b =
|
||||
if Term.equal a b then (
|
||||
Term.true_ st
|
||||
) else (
|
||||
let a,b = if Term.id a > Term.id b then b, a else a, b in
|
||||
Term.app_cst st C.eq (IArray.doubleton a b)
|
||||
)
|
||||
let eq = Term.eq
|
||||
|
||||
let not_ st a =
|
||||
match as_id id_not a, Term.view a with
|
||||
|
|
@ -164,7 +155,7 @@ let distinct st = function
|
|||
module Lit = struct
|
||||
include Lit
|
||||
let eq tst a b = Lit.atom ~sign:true (eq tst a b)
|
||||
let neq tst a b = Lit.atom ~sign:false (neq tst a b)
|
||||
let neq tst a b = neg @@ eq tst a b
|
||||
end
|
||||
|
||||
type t = {
|
||||
|
|
@ -175,14 +166,8 @@ let tseitin (_self:t) (acts:Theory.actions) (lit:Lit.t) (lit_t:term) (v:term vie
|
|||
let (module A) = acts in
|
||||
Log.debugf 5 (fun k->k "(@[th_bool.tseitin@ %a@])" Lit.pp lit);
|
||||
match v with
|
||||
| B_atom _ -> ()
|
||||
| B_not _ -> assert false (* normalized *)
|
||||
| B_eq (t,u) ->
|
||||
if Lit.sign lit then (
|
||||
A.propagate_eq t u [lit]
|
||||
) else (
|
||||
A.propagate_distinct [t;u] ~neq:lit_t lit
|
||||
)
|
||||
| B_atom _ | B_eq _ -> () (* CC will manage *)
|
||||
| B_distinct l ->
|
||||
let l = IArray.to_list l in
|
||||
if Lit.sign lit then (
|
||||
|
|
@ -197,7 +182,7 @@ let tseitin (_self:t) (acts:Theory.actions) (lit:Lit.t) (lit_t:term) (v:term vie
|
|||
IArray.iter
|
||||
(fun sub ->
|
||||
let sublit = Lit.atom sub in
|
||||
A.propagate sublit [lit])
|
||||
A.add_local_axiom [Lit.neg lit; sublit])
|
||||
subs
|
||||
) else (
|
||||
(* propagate [¬lit => ∨_i ¬ subs_i] *)
|
||||
|
|
@ -216,7 +201,7 @@ let tseitin (_self:t) (acts:Theory.actions) (lit:Lit.t) (lit_t:term) (v:term vie
|
|||
IArray.iter
|
||||
(fun sub ->
|
||||
let sublit = Lit.atom ~sign:false sub in
|
||||
A.propagate sublit [lit])
|
||||
A.add_local_axiom [Lit.neg lit; sublit])
|
||||
subs
|
||||
)
|
||||
| B_imply (guard,concl) ->
|
||||
|
|
@ -239,7 +224,7 @@ let tseitin (_self:t) (acts:Theory.actions) (lit:Lit.t) (lit_t:term) (v:term vie
|
|||
let partial_check (self:t) acts (lits:Lit.t Sequence.t) =
|
||||
lits
|
||||
(fun lit ->
|
||||
let t = Lit.view lit in
|
||||
let t = Lit.term lit in
|
||||
match view t with
|
||||
| B_atom _ -> ()
|
||||
| v -> tseitin self acts lit t v)
|
||||
|
|
|
|||
|
|
@ -9,3 +9,4 @@ module Backtrack_stack = Backtrack_stack
|
|||
module Error = Error
|
||||
module IArray = IArray
|
||||
module Intf = Intf
|
||||
module Bag = Bag
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue