mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-07 03:35:38 -05:00
refactor(cc): no micro theories, only callbacks
This commit is contained in:
parent
632bec0e66
commit
ddde590ffd
7 changed files with 217 additions and 412 deletions
|
|
@ -1,115 +0,0 @@
|
|||
|
||||
(** {1 Types used by the congruence closure} *)
|
||||
|
||||
type ('f, 't, 'ts) view =
|
||||
| Bool of bool
|
||||
| App_fun of 'f * 'ts
|
||||
| App_ho of 't * 'ts
|
||||
| If of 't * 't * 't
|
||||
| Eq of 't * 't
|
||||
| Not of 't
|
||||
| Opaque of 't (* do not enter *)
|
||||
|
||||
let[@inline] map_view ~f_f ~f_t ~f_ts (v:_ view) : _ view =
|
||||
match v with
|
||||
| Bool b -> Bool b
|
||||
| App_fun (f, args) -> App_fun (f_f f, f_ts args)
|
||||
| App_ho (f, args) -> App_ho (f_t f, f_ts args)
|
||||
| Not t -> Not (f_t t)
|
||||
| If (a,b,c) -> If (f_t a, f_t b, f_t c)
|
||||
| Eq (a,b) -> Eq (f_t a, f_t b)
|
||||
| Opaque t -> Opaque (f_t t)
|
||||
|
||||
let iter_view ~f_f ~f_t ~f_ts (v:_ view) : unit =
|
||||
match v with
|
||||
| Bool _ -> ()
|
||||
| App_fun (f, args) -> f_f f; f_ts args
|
||||
| App_ho (f, args) -> f_t f; f_ts args
|
||||
| Not t -> f_t t
|
||||
| If (a,b,c) -> f_t a; f_t b; f_t c;
|
||||
| Eq (a,b) -> f_t a; f_t b
|
||||
| Opaque t -> f_t t
|
||||
|
||||
module type TERM = sig
|
||||
module Fun : sig
|
||||
type t
|
||||
val equal : t -> t -> bool
|
||||
val hash : t -> int
|
||||
val pp : t Fmt.printer
|
||||
end
|
||||
|
||||
module Term : sig
|
||||
type t
|
||||
val equal : t -> t -> bool
|
||||
val hash : t -> int
|
||||
val pp : t Fmt.printer
|
||||
|
||||
type state
|
||||
|
||||
val bool : state -> bool -> t
|
||||
|
||||
(** View the term through the lens of the congruence closure *)
|
||||
val cc_view : t -> (Fun.t, t, t Iter.t) view
|
||||
end
|
||||
end
|
||||
|
||||
module type TERM_LIT = sig
|
||||
include TERM
|
||||
|
||||
module Lit : sig
|
||||
type t
|
||||
val neg : t -> t
|
||||
val equal : t -> t -> bool
|
||||
val hash : t -> int
|
||||
val pp : t Fmt.printer
|
||||
|
||||
val sign : t -> bool
|
||||
val term : t -> Term.t
|
||||
end
|
||||
end
|
||||
|
||||
module type FULL = sig
|
||||
include TERM_LIT
|
||||
|
||||
module Proof : sig
|
||||
type t
|
||||
val pp : t Fmt.printer
|
||||
|
||||
val default : t
|
||||
(* TODO: to give more details
|
||||
val cc_lemma : unit -> t
|
||||
*)
|
||||
end
|
||||
|
||||
module Ty : sig
|
||||
type t
|
||||
|
||||
val equal : t -> t -> bool
|
||||
val hash : t -> int
|
||||
val pp : t Fmt.printer
|
||||
end
|
||||
|
||||
module Value : sig
|
||||
type t
|
||||
|
||||
val pp : t Fmt.printer
|
||||
|
||||
val fresh : Term.t -> t
|
||||
|
||||
val true_ : t
|
||||
val false_ : t
|
||||
end
|
||||
|
||||
module Model : sig
|
||||
type t
|
||||
|
||||
val pp : t Fmt.printer
|
||||
|
||||
val eval : t -> Term.t -> Value.t option
|
||||
(** Evaluate the term in the current model *)
|
||||
|
||||
val add : Term.t -> Value.t -> t -> t
|
||||
end
|
||||
end
|
||||
|
||||
(* TODO: micro theory *)
|
||||
|
|
@ -1,66 +1,9 @@
|
|||
|
||||
open CC_types
|
||||
open Congruence_closure_intf
|
||||
|
||||
module type ARG = Congruence_closure_intf.ARG
|
||||
module type S = Congruence_closure_intf.S
|
||||
|
||||
module type THEORY_KEY = Congruence_closure_intf.THEORY_KEY
|
||||
|
||||
(** Custom keys for theory data.
|
||||
This imitates the classic tricks for heterogeneous maps
|
||||
https://blog.janestreet.com/a-universal-type/
|
||||
|
||||
It needs to form a commutative monoid where values are persistent so
|
||||
they can be restored during backtracking.
|
||||
*)
|
||||
module Key = struct
|
||||
module type KEY_IMPL = sig
|
||||
type term
|
||||
type lit
|
||||
type t
|
||||
val id : int
|
||||
val name : string
|
||||
val pp : t Fmt.printer
|
||||
val equal : t -> t -> bool
|
||||
val merge : t -> t -> t
|
||||
exception Store of t
|
||||
end
|
||||
|
||||
type ('term,'lit,'a) t =
|
||||
(module KEY_IMPL with type term = 'term and type lit = 'lit and type t = 'a)
|
||||
|
||||
let n_ = ref 0
|
||||
|
||||
let create (type term)(type lit)(type d)
|
||||
?(pp=fun out _ -> Fmt.string out "<opaque>")
|
||||
~name ~eq ~merge () : (term,lit,d) t =
|
||||
let module K = struct
|
||||
type nonrec term = term
|
||||
type nonrec lit = lit
|
||||
type t = d
|
||||
let id = !n_
|
||||
let name = name
|
||||
let pp = pp
|
||||
let merge = merge
|
||||
let equal = eq
|
||||
exception Store of d
|
||||
end in
|
||||
incr n_;
|
||||
(module K)
|
||||
|
||||
let[@inline] id
|
||||
: type term lit a. (term,lit,a) t -> int
|
||||
= fun (module K) -> K.id
|
||||
|
||||
let[@inline] equal
|
||||
: type term lit a b. (term,lit,a) t -> (term,lit,b) t -> bool
|
||||
= fun (module K1) (module K2) -> K1.id = K2.id
|
||||
|
||||
let pp
|
||||
: type term lit a. (term,lit,a) t Fmt.printer
|
||||
= fun out (module K) -> Fmt.string out K.name
|
||||
end
|
||||
|
||||
module Bits = CCBitField.Make()
|
||||
|
||||
let field_is_pending = Bits.mk_field()
|
||||
|
|
@ -81,6 +24,7 @@ module Make(A: ARG) = struct
|
|||
type fun_ = A.Fun.t
|
||||
type proof = A.Proof.t
|
||||
type model = A.Model.t
|
||||
type th_data = A.Data.t
|
||||
|
||||
(** Actions available to the theory *)
|
||||
type sat_actions = (Msat.void, lit, Msat.void, proof) Msat.acts
|
||||
|
|
@ -88,60 +32,6 @@ module Make(A: ARG) = struct
|
|||
module T = A.Term
|
||||
module Fun = A.Fun
|
||||
module Key = Key
|
||||
module IM = Map.Make(CCInt)
|
||||
|
||||
(** Map for theory data associated with representatives *)
|
||||
module K_map = struct
|
||||
type 'a key = (term,lit,'a) Key.t
|
||||
type pair = Pair : 'a key * exn -> pair
|
||||
|
||||
type t = pair IM.t
|
||||
|
||||
let empty = IM.empty
|
||||
|
||||
let[@inline] mem k t = IM.mem (Key.id k) t
|
||||
|
||||
let find (type a) (k : a key) (self:t) : a option =
|
||||
let (module K) = k in
|
||||
match IM.find K.id self with
|
||||
| Pair (_, K.Store v) -> Some v
|
||||
| _ -> None
|
||||
| exception Not_found -> None
|
||||
|
||||
let add (type a) (k : a key) (v:a) (self:t) : t =
|
||||
let (module K) = k in
|
||||
IM.add K.id (Pair (k, K.Store v)) self
|
||||
|
||||
let remove (type a) (k: a key) self : t =
|
||||
let (module K) = k in
|
||||
IM.remove K.id self
|
||||
|
||||
let equal (m1:t) (m2:t) : bool =
|
||||
IM.equal
|
||||
(fun p1 p2 ->
|
||||
let Pair ((module K1), v1) = p1 in
|
||||
let Pair ((module K2), v2) = p2 in
|
||||
assert (K1.id = K2.id);
|
||||
match v1, v2 with K1.Store v1, K1.Store v2 -> K1.equal v1 v2 | _ -> false)
|
||||
m1 m2
|
||||
|
||||
let merge ~f_both (m1:t) (m2:t) : t =
|
||||
IM.merge
|
||||
(fun _ p1 p2 ->
|
||||
match p1, p2 with
|
||||
| None, None -> None
|
||||
| Some v, None
|
||||
| None, Some v -> Some v
|
||||
| Some (Pair ((module K1) as key1, pair1)), Some (Pair (_, pair2)) ->
|
||||
match pair1, pair2 with
|
||||
| K1.Store v1, K1.Store v2 ->
|
||||
f_both K1.id pair1 pair2; (* callback for checking compat *)
|
||||
let v12 = K1.merge v1 v2 in (* merge content *)
|
||||
Some (Pair (key1, K1.Store v12))
|
||||
| _ -> assert false
|
||||
)
|
||||
m1 m2
|
||||
end
|
||||
|
||||
(** A node of the congruence closure.
|
||||
An equivalence class is represented by its "root" element,
|
||||
|
|
@ -156,7 +46,7 @@ module Make(A: ARG) = struct
|
|||
mutable n_size: int; (* size of the class *)
|
||||
mutable n_as_lit: lit option; (* TODO: put into payload? and only in root? *)
|
||||
mutable n_expl: explanation_forest_link; (* the rooted forest for explanations *)
|
||||
mutable n_th_data: K_map.t; (* theory data *)
|
||||
mutable n_th_data: th_data; (* theory data *)
|
||||
}
|
||||
|
||||
and signature = (fun_, node, node list) view
|
||||
|
|
@ -173,8 +63,9 @@ module Make(A: ARG) = struct
|
|||
| E_reduction (* by pure reduction, tautologically equal *)
|
||||
| E_lit of lit (* because of this literal *)
|
||||
| E_merge of node * node
|
||||
| E_list of explanation list
|
||||
| E_merge_t of term * term
|
||||
| E_congruence of node * node (* caused by normal congruence *)
|
||||
| E_and of explanation * explanation
|
||||
|
||||
type repr = node
|
||||
type conflict = lit list
|
||||
|
|
@ -182,11 +73,12 @@ module Make(A: ARG) = struct
|
|||
module N = struct
|
||||
type t = node
|
||||
|
||||
let[@inline] equal (n1:t) n2 = T.equal n1.n_term n2.n_term
|
||||
let[@inline] equal (n1:t) n2 = n1 == n2
|
||||
let[@inline] hash n = T.hash n.n_term
|
||||
let[@inline] term n = n.n_term
|
||||
let[@inline] pp out n = T.pp out n.n_term
|
||||
let[@inline] as_lit n = n.n_as_lit
|
||||
let[@inline] th_data n = n.n_th_data
|
||||
|
||||
let make (t:term) : t =
|
||||
let rec n = {
|
||||
|
|
@ -199,7 +91,7 @@ module Make(A: ARG) = struct
|
|||
n_expl=FL_none;
|
||||
n_next=n;
|
||||
n_size=1;
|
||||
n_th_data=K_map.empty;
|
||||
n_th_data=A.Data.empty;
|
||||
} in
|
||||
n
|
||||
|
||||
|
|
@ -214,7 +106,7 @@ module Make(A: ARG) = struct
|
|||
in
|
||||
aux n
|
||||
|
||||
let iter_class n =
|
||||
let[@inline] iter_class n =
|
||||
assert (is_root n);
|
||||
iter_class_ n
|
||||
|
||||
|
|
@ -224,6 +116,11 @@ module Make(A: ARG) = struct
|
|||
|
||||
let[@inline] get_field f t = Bits.get f t.n_bits
|
||||
let[@inline] set_field f b t = t.n_bits <- Bits.set f b t.n_bits
|
||||
|
||||
let[@inline] get_field_usr1 t = get_field field_usr1 t
|
||||
let[@inline] set_field_usr1 t b = set_field field_usr1 b t
|
||||
let[@inline] get_field_usr2 t = get_field field_usr2 t
|
||||
let[@inline] set_field_usr2 t b = set_field field_usr2 b t
|
||||
end
|
||||
|
||||
module N_tbl = CCHashtbl.Make(N)
|
||||
|
|
@ -236,19 +133,25 @@ module Make(A: ARG) = struct
|
|||
| E_lit lit -> A.Lit.pp out lit
|
||||
| E_congruence (n1,n2) -> Fmt.fprintf out "(@[congruence@ %a@ %a@])" N.pp n1 N.pp n2
|
||||
| E_merge (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" N.pp a N.pp b
|
||||
| E_list l ->
|
||||
Format.fprintf out "(@[<hv1>and@ %a@])"
|
||||
Fmt.(list ~sep:(return "@ ") @@ within "[" "]" @@ hvbox @@ pp) l
|
||||
| E_merge_t (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" T.pp a T.pp b
|
||||
| E_and (a,b) ->
|
||||
Format.fprintf out "(@[<hv1>and@ %a@ %a@])" pp a pp b
|
||||
|
||||
let mk_reduction : t = E_reduction
|
||||
let[@inline] mk_congruence n1 n2 : t = E_congruence (n1,n2)
|
||||
let[@inline] mk_merge a b : t = E_merge (a,b)
|
||||
let[@inline] mk_merge a b : t = if N.equal a b then mk_reduction else E_merge (a,b)
|
||||
let[@inline] mk_merge_t a b : t = if T.equal a b then mk_reduction else E_merge_t (a,b)
|
||||
let[@inline] mk_lit l : t = E_lit l
|
||||
let mk_list l =
|
||||
|
||||
let rec mk_list l =
|
||||
match l with
|
||||
| [] -> mk_reduction
|
||||
| [x] -> x
|
||||
| l -> E_list l
|
||||
| E_reduction :: tl -> mk_list tl
|
||||
| x :: y ->
|
||||
match mk_list y with
|
||||
| E_reduction -> x
|
||||
| y' -> E_and (x,y')
|
||||
end
|
||||
|
||||
(** A signature is a shallow term shape where immediate subterms
|
||||
|
|
@ -302,15 +205,6 @@ module Make(A: ARG) = struct
|
|||
type combine_task =
|
||||
| CT_merge of node * node * explanation
|
||||
|
||||
module type THEORY = sig
|
||||
type cc
|
||||
type data
|
||||
val key_id : int
|
||||
val key : (term,lit,data) Key.t
|
||||
val on_merge : cc -> N.t -> data -> N.t -> data -> Expl.t -> unit
|
||||
val on_new_term: cc -> N.t -> term -> data option
|
||||
end
|
||||
|
||||
type t = {
|
||||
tst: term_state;
|
||||
tbl: node T_tbl.t;
|
||||
|
|
@ -326,8 +220,8 @@ module Make(A: ARG) = struct
|
|||
pending: node Vec.t;
|
||||
combine: combine_task Vec.t;
|
||||
undo: (unit -> unit) Backtrack_stack.t;
|
||||
mutable theories: theory IM.t;
|
||||
mutable on_merge: (t -> N.t -> N.t -> Expl.t -> unit) list;
|
||||
mutable on_merge: ev_on_merge list;
|
||||
mutable on_new_term: ev_on_new_term list;
|
||||
mutable ps_lits: lit list; (* TODO: thread it around instead? *)
|
||||
(* proof state *)
|
||||
ps_queue: (node*node) Vec.t;
|
||||
|
|
@ -344,9 +238,8 @@ module Make(A: ARG) = struct
|
|||
several times.
|
||||
See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *)
|
||||
|
||||
and theory = (module THEORY with type cc = t)
|
||||
|
||||
type cc = t
|
||||
and ev_on_merge = t -> N.t -> th_data -> N.t -> th_data -> Expl.t -> unit
|
||||
and ev_on_new_term = t -> N.t -> term -> th_data -> th_data option
|
||||
|
||||
let[@inline] size_ (r:repr) = r.n_size
|
||||
let[@inline] true_ cc = Lazy.force cc.true_
|
||||
|
|
@ -359,10 +252,6 @@ module Make(A: ARG) = struct
|
|||
(* check if [t] is in the congruence closure.
|
||||
Invariant: [in_cc t ∧ do_cc t => forall u subterm t, in_cc u] *)
|
||||
let[@inline] mem (cc:t) (t:term): bool = T_tbl.mem cc.tbl t
|
||||
(* FIXME
|
||||
let on_merge cc f = cc.on_merge <- f :: cc.on_merge
|
||||
let on_new_term cc f = cc.on_new_term <- f :: cc.on_new_term
|
||||
*)
|
||||
|
||||
(* find representative, recursively *)
|
||||
let[@unroll 2] rec find_rec (n:node) : repr =
|
||||
|
|
@ -408,7 +297,7 @@ module Make(A: ARG) = struct
|
|||
|
||||
(* compute up-to-date signature *)
|
||||
let update_sig (s:signature) : Signature.t =
|
||||
CC_types.map_view s
|
||||
Congruence_closure_intf.map_view s
|
||||
~f_f:(fun x->x)
|
||||
~f_t:find_
|
||||
~f_ts:(List.map find_)
|
||||
|
|
@ -475,7 +364,7 @@ module Make(A: ARG) = struct
|
|||
| FL_none -> 0
|
||||
| FL_some {next=t'; _} -> 1 + distance_to_root t'
|
||||
|
||||
(* TODO: bool flag on nodes + stepwise progress + cleanup *)
|
||||
(* TODO: new bool flag on nodes + stepwise progress + cleanup *)
|
||||
(* find the closest common ancestor of [a] and [b] in the proof forest *)
|
||||
let find_common_ancestor (a:node) (b:node) : node =
|
||||
let d_a = distance_to_root a in
|
||||
|
|
@ -505,13 +394,11 @@ module Make(A: ARG) = struct
|
|||
let[@inline] ps_add_obligation (cc:t) a b = Vec.push cc.ps_queue (a,b)
|
||||
let[@inline] ps_add_lit ps l = ps.ps_lits <- l :: ps.ps_lits
|
||||
|
||||
(* TODO: remove *)
|
||||
let ps_clear (cc:t) =
|
||||
cc.ps_lits <- [];
|
||||
Vec.clear cc.ps_queue;
|
||||
()
|
||||
|
||||
(* TODO: turn this into a fold? *)
|
||||
(* decompose explanation [e] of why [n1 = n2] *)
|
||||
let rec decompose_explain cc (e:explanation) : unit =
|
||||
Log.debugf 5 (fun k->k "(@[cc.decompose_expl@ %a@])" Expl.pp e);
|
||||
|
|
@ -536,7 +423,14 @@ module Make(A: ARG) = struct
|
|||
end
|
||||
| E_lit lit -> ps_add_lit cc lit
|
||||
| E_merge (a,b) -> ps_add_obligation cc a b
|
||||
| E_list l -> List.iter (decompose_explain cc) l
|
||||
| E_merge_t (a,b) ->
|
||||
(* find nodes for [a] and [b] on the fly *)
|
||||
begin match T_tbl.find cc.tbl a, T_tbl.find cc.tbl b with
|
||||
| a, b -> ps_add_obligation cc a b
|
||||
| exception Not_found ->
|
||||
Error.errorf "expl: cannot find node(s) for %a, %a" T.pp a T.pp b
|
||||
end
|
||||
| E_and (a,b) -> decompose_explain cc a; decompose_explain cc b
|
||||
|
||||
(* explain why [a = parent_a], where [a -> ... -> parent_a] in the
|
||||
proof forest *)
|
||||
|
|
@ -565,7 +459,6 @@ module Make(A: ARG) = struct
|
|||
done;
|
||||
cc.ps_lits
|
||||
|
||||
(* TODO: do not use ps_lits anymore *)
|
||||
let explain_eq_n ?(init=[]) cc (n1:node) (n2:node) : lit list =
|
||||
ps_clear cc;
|
||||
cc.ps_lits <- init;
|
||||
|
|
@ -604,15 +497,15 @@ module Make(A: ARG) = struct
|
|||
push_pending cc n;
|
||||
);
|
||||
(* initial theory data *)
|
||||
let th_map =
|
||||
IM.fold
|
||||
(fun _ (module Th: THEORY with type cc=cc) th_map ->
|
||||
match Th.on_new_term cc n t with
|
||||
| None -> th_map
|
||||
| Some v -> K_map.add Th.key v th_map)
|
||||
cc.theories K_map.empty
|
||||
let th_data =
|
||||
List.fold_left
|
||||
(fun data f ->
|
||||
match f cc n t data with
|
||||
| None -> data
|
||||
| Some d -> d)
|
||||
A.Data.empty cc.on_new_term
|
||||
in
|
||||
n.n_th_data <- th_map;
|
||||
n.n_th_data <- th_data;
|
||||
n
|
||||
|
||||
(* compute the initial signature of the given node *)
|
||||
|
|
@ -754,36 +647,19 @@ module Make(A: ARG) = struct
|
|||
merge_bool rb b ra a;
|
||||
(* perform [union r_from r_into] *)
|
||||
Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into);
|
||||
(* call [on_merge] functions *)
|
||||
List.iter (fun f -> f cc r_into r_from e_ab) cc.on_merge;
|
||||
(* call micro theories *)
|
||||
(* call [on_merge] functions, and merge theory data items *)
|
||||
begin
|
||||
let th_into = r_into.n_th_data in
|
||||
let th_from = r_from.n_th_data in
|
||||
(* merge the two maps; if a key occurs in both, looks for theories with
|
||||
this particular key *)
|
||||
let th =
|
||||
K_map.merge th_into th_from
|
||||
~f_both:(fun id pair_into pair_from ->
|
||||
match IM.find id cc.theories with
|
||||
| (module Th : THEORY with type cc=t) ->
|
||||
(* casting magic *)
|
||||
let (module K) = Th.key in
|
||||
begin match pair_into, pair_from with
|
||||
| K.Store v_into, K.Store v_from ->
|
||||
Log.debugf 15
|
||||
(fun k->k "(@[cc.merge.th-on-merge@ :th %s@])" K.name);
|
||||
(* FIXME: explanation is a=ra, e_ab, b=rb *)
|
||||
Th.on_merge cc r_into v_into r_from v_from e_ab
|
||||
| _ -> assert false
|
||||
end
|
||||
| exception Not_found -> ())
|
||||
in
|
||||
let new_data = A.Data.merge th_into th_from in
|
||||
(* restore old data, if it changed *)
|
||||
if not @@ K_map.equal th th_into then (
|
||||
if new_data != th_into then (
|
||||
on_backtrack cc (fun () -> r_into.n_th_data <- th_into);
|
||||
);
|
||||
r_into.n_th_data <- th;
|
||||
r_into.n_th_data <- new_data;
|
||||
(* explanation is [a=ra & e_ab & b=rb] *)
|
||||
let expl = Expl.mk_list [e_ab; Expl.mk_merge a ra; Expl.mk_merge b rb] in
|
||||
List.iter (fun f -> f cc r_into th_into r_from th_from expl) cc.on_merge;
|
||||
end;
|
||||
begin
|
||||
(* parents might have a different signature, check for collisions *)
|
||||
|
|
@ -864,8 +740,6 @@ module Make(A: ARG) = struct
|
|||
|
||||
module Theory = struct
|
||||
type cc = t
|
||||
type t = theory
|
||||
type 'a key = (term,lit,'a) Key.t
|
||||
|
||||
(* raise a conflict *)
|
||||
let raise_conflict cc expl =
|
||||
|
|
@ -879,41 +753,6 @@ module Make(A: ARG) = struct
|
|||
merge_classes cc n1 n2 expl
|
||||
|
||||
let add_term = add_term
|
||||
|
||||
let get_data _cc n key =
|
||||
assert (N.is_root n);
|
||||
K_map.find key n.n_th_data
|
||||
|
||||
(* FIXME: call micro theory here? in case of merge *)
|
||||
(* update data for [n] *)
|
||||
let add_data (type a) (self:cc) (n:node) (key: a key) (v:a) : unit =
|
||||
let n = find_ n in
|
||||
let map = n.n_th_data in
|
||||
let old_v = K_map.find key map in
|
||||
let v', is_diff = match old_v with
|
||||
| None -> v, true
|
||||
| Some old_v ->
|
||||
let (module K) = key in
|
||||
let v' = K.merge old_v v in
|
||||
v', K.equal v v'
|
||||
in
|
||||
if is_diff then (
|
||||
on_backtrack self (fun () -> n.n_th_data <- map);
|
||||
);
|
||||
n.n_th_data <- K_map.add key v' map;
|
||||
()
|
||||
|
||||
let make (type a) ~(key:a key) ~on_merge ~on_new_term () : t =
|
||||
let module Th = struct
|
||||
type nonrec cc = cc
|
||||
type data = a
|
||||
let key = key
|
||||
let key_id = Key.id key
|
||||
let on_merge = on_merge
|
||||
let on_new_term = on_new_term
|
||||
end in
|
||||
(module Th : THEORY with type cc=cc)
|
||||
|
||||
end
|
||||
|
||||
let check_invariants_ (cc:t) =
|
||||
|
|
@ -1000,25 +839,18 @@ module Make(A: ARG) = struct
|
|||
let n2 = add_term cc t2 in
|
||||
merge_classes cc n1 n2 expl
|
||||
|
||||
let add_th (self:t) (th:theory) : unit =
|
||||
let (module Th) = th in
|
||||
if IM.mem Th.key_id self.theories then (
|
||||
Error.errorf "attempt to add two theories with key %a" Key.pp Th.key
|
||||
);
|
||||
Log.debugf 3 (fun k->k "(@[@{<green>cc.add-theory@} %a@])" Key.pp Th.key);
|
||||
self.theories <- IM.add Th.key_id th self.theories
|
||||
|
||||
let on_merge cc f = cc.on_merge <- f :: cc.on_merge
|
||||
let on_new_term cc f = cc.on_new_term <- f :: cc.on_new_term
|
||||
|
||||
let create ?(stat=Stat.global)
|
||||
?th:(theories=[]) ?(on_merge=[]) ?(size=`Big) (tst:term_state) : t =
|
||||
?(on_merge=[]) ?(on_new_term=[]) ?(size=`Big) (tst:term_state) : t =
|
||||
let size = match size with `Small -> 128 | `Big -> 2048 in
|
||||
let rec cc = {
|
||||
tst;
|
||||
tbl = T_tbl.create size;
|
||||
signatures_tbl = Sig_tbl.create size;
|
||||
theories=IM.empty;
|
||||
on_merge;
|
||||
on_new_term;
|
||||
pending=Vec.create();
|
||||
combine=Vec.create();
|
||||
ps_lits=[];
|
||||
|
|
@ -1037,7 +869,6 @@ module Make(A: ARG) = struct
|
|||
in
|
||||
ignore (Lazy.force true_ : node);
|
||||
ignore (Lazy.force false_ : node);
|
||||
List.iter (add_th cc) theories; (* now add theories *)
|
||||
cc
|
||||
|
||||
let[@inline] find_t cc t : repr =
|
||||
|
|
|
|||
|
|
@ -3,9 +3,6 @@
|
|||
module type ARG = Congruence_closure_intf.ARG
|
||||
module type S = Congruence_closure_intf.S
|
||||
|
||||
module type THEORY_KEY = Congruence_closure_intf.THEORY_KEY
|
||||
module Key : THEORY_KEY
|
||||
|
||||
module Make(A: ARG)
|
||||
: S with type term = A.Term.t
|
||||
and type lit = A.Lit.t
|
||||
|
|
@ -13,4 +10,4 @@ module Make(A: ARG)
|
|||
and type term_state = A.Term.state
|
||||
and type proof = A.Proof.t
|
||||
and type model = A.Model.t
|
||||
and module Key = Key
|
||||
and type th_data = A.Data.t
|
||||
|
|
|
|||
|
|
@ -1,35 +1,124 @@
|
|||
|
||||
module type ARG = CC_types.FULL
|
||||
(** {1 Types used by the congruence closure} *)
|
||||
|
||||
module type THEORY_KEY = sig
|
||||
type ('term,'lit,'a) t
|
||||
(** An access key for theories which have per-class data ['a] *)
|
||||
type ('f, 't, 'ts) view =
|
||||
| Bool of bool
|
||||
| App_fun of 'f * 'ts
|
||||
| App_ho of 't * 'ts
|
||||
| If of 't * 't * 't
|
||||
| Eq of 't * 't
|
||||
| Not of 't
|
||||
| Opaque of 't (* do not enter *)
|
||||
|
||||
val create :
|
||||
?pp:'a Fmt.printer ->
|
||||
name:string ->
|
||||
eq:('a -> 'a -> bool) ->
|
||||
merge:('a -> 'a -> 'a) ->
|
||||
unit ->
|
||||
('term,'lit,'a) t
|
||||
(** Generative creation of keys for the given theory data.
|
||||
let[@inline] map_view ~f_f ~f_t ~f_ts (v:_ view) : _ view =
|
||||
match v with
|
||||
| Bool b -> Bool b
|
||||
| App_fun (f, args) -> App_fun (f_f f, f_ts args)
|
||||
| App_ho (f, args) -> App_ho (f_t f, f_ts args)
|
||||
| Not t -> Not (f_t t)
|
||||
| If (a,b,c) -> If (f_t a, f_t b, f_t c)
|
||||
| Eq (a,b) -> Eq (f_t a, f_t b)
|
||||
| Opaque t -> Opaque (f_t t)
|
||||
|
||||
@param eq : Equality. This is used to optimize backtracking info.
|
||||
let iter_view ~f_f ~f_t ~f_ts (v:_ view) : unit =
|
||||
match v with
|
||||
| Bool _ -> ()
|
||||
| App_fun (f, args) -> f_f f; f_ts args
|
||||
| App_ho (f, args) -> f_t f; f_ts args
|
||||
| Not t -> f_t t
|
||||
| If (a,b,c) -> f_t a; f_t b; f_t c;
|
||||
| Eq (a,b) -> f_t a; f_t b
|
||||
| Opaque t -> f_t t
|
||||
|
||||
@param merge :
|
||||
[merge d1 d2] is called when merging classes with data [d1] and [d2]
|
||||
respectively. The theory should already have checked that the merge
|
||||
is compatible, and this produces the combined data for terms in the
|
||||
merged class.
|
||||
@param name name of the theory which owns this data
|
||||
@param pp a printer for the data
|
||||
*)
|
||||
module type TERM = sig
|
||||
module Fun : sig
|
||||
type t
|
||||
val equal : t -> t -> bool
|
||||
val hash : t -> int
|
||||
val pp : t Fmt.printer
|
||||
end
|
||||
|
||||
val equal : ('t,'lit,_) t -> ('t,'lit,_) t -> bool
|
||||
(** Checks if two keys are equal (generatively) *)
|
||||
module Term : sig
|
||||
type t
|
||||
val equal : t -> t -> bool
|
||||
val hash : t -> int
|
||||
val pp : t Fmt.printer
|
||||
|
||||
val pp : _ t Fmt.printer
|
||||
(** Prints the name of the key. *)
|
||||
type state
|
||||
|
||||
val bool : state -> bool -> t
|
||||
|
||||
(** View the term through the lens of the congruence closure *)
|
||||
val cc_view : t -> (Fun.t, t, t Iter.t) view
|
||||
end
|
||||
end
|
||||
|
||||
module type TERM_LIT = sig
|
||||
include TERM
|
||||
|
||||
module Lit : sig
|
||||
type t
|
||||
val neg : t -> t
|
||||
val equal : t -> t -> bool
|
||||
val hash : t -> int
|
||||
val pp : t Fmt.printer
|
||||
|
||||
val sign : t -> bool
|
||||
val term : t -> Term.t
|
||||
end
|
||||
end
|
||||
|
||||
module type ARG = sig
|
||||
include TERM_LIT
|
||||
|
||||
module Proof : sig
|
||||
type t
|
||||
val pp : t Fmt.printer
|
||||
|
||||
val default : t
|
||||
(* TODO: to give more details
|
||||
val cc_lemma : unit -> t
|
||||
*)
|
||||
end
|
||||
|
||||
module Ty : sig
|
||||
type t
|
||||
|
||||
val equal : t -> t -> bool
|
||||
val hash : t -> int
|
||||
val pp : t Fmt.printer
|
||||
end
|
||||
|
||||
module Value : sig
|
||||
type t
|
||||
|
||||
val pp : t Fmt.printer
|
||||
|
||||
val fresh : Term.t -> t
|
||||
|
||||
val true_ : t
|
||||
val false_ : t
|
||||
end
|
||||
|
||||
module Model : sig
|
||||
type t
|
||||
|
||||
val pp : t Fmt.printer
|
||||
|
||||
val eval : t -> Term.t -> Value.t option
|
||||
(** Evaluate the term in the current model *)
|
||||
|
||||
val add : Term.t -> Value.t -> t -> t
|
||||
end
|
||||
|
||||
(** Monoid embedded in every node *)
|
||||
module Data : sig
|
||||
type t
|
||||
|
||||
val empty : t
|
||||
|
||||
val merge : t -> t -> t
|
||||
end
|
||||
end
|
||||
|
||||
module type S = sig
|
||||
|
|
@ -39,9 +128,7 @@ module type S = sig
|
|||
type lit
|
||||
type proof
|
||||
type model
|
||||
|
||||
(** Implementation of theory keys *)
|
||||
module Key : THEORY_KEY
|
||||
type th_data
|
||||
|
||||
type t
|
||||
(** Global state of the congruence closure *)
|
||||
|
|
@ -80,6 +167,15 @@ module type S = sig
|
|||
val iter_parents : t -> t Iter.t
|
||||
(** Traverse the parents of the class.
|
||||
Invariant: [is_root n] (see {!find} below) *)
|
||||
|
||||
val th_data : t -> th_data
|
||||
(** Access theory data for this node *)
|
||||
|
||||
val get_field_usr1 : t -> bool
|
||||
val set_field_usr1 : t -> bool -> unit
|
||||
|
||||
val get_field_usr2 : t -> bool
|
||||
val set_field_usr2 : t -> bool -> unit
|
||||
end
|
||||
|
||||
module Expl : sig
|
||||
|
|
@ -87,6 +183,7 @@ module type S = sig
|
|||
val pp : t Fmt.printer
|
||||
|
||||
val mk_merge : N.t -> N.t -> t
|
||||
val mk_merge_t : term -> term -> t
|
||||
val mk_lit : lit -> t
|
||||
val mk_list : t list -> t
|
||||
end
|
||||
|
|
@ -117,9 +214,6 @@ module type S = sig
|
|||
|
||||
module Theory : sig
|
||||
type cc = t
|
||||
type t
|
||||
|
||||
type 'a key = (term,lit,'a) Key.t
|
||||
|
||||
val raise_conflict : cc -> Expl.t -> unit
|
||||
(** Raise a conflict with the given explanation
|
||||
|
|
@ -134,39 +228,26 @@ module type S = sig
|
|||
val add_term : cc -> term -> N.t
|
||||
(** Add/retrieve node for this term.
|
||||
To be used in theories *)
|
||||
|
||||
val get_data : cc -> N.t -> 'a key -> 'a option
|
||||
(** Get data information for this particular representative *)
|
||||
|
||||
val add_data : cc -> N.t -> 'a key -> 'a -> unit
|
||||
(** Add data to this particular representative. Will be backtracked. *)
|
||||
|
||||
val make :
|
||||
key:'a key ->
|
||||
on_merge:(cc -> N.t -> 'a -> N.t -> 'a -> Expl.t -> unit) ->
|
||||
on_new_term:(cc -> N.t -> term -> 'a option) ->
|
||||
unit ->
|
||||
t
|
||||
(** Build a micro theory. It can use the callbacks above. *)
|
||||
end
|
||||
|
||||
type ev_on_merge = t -> N.t -> th_data -> N.t -> th_data -> Expl.t -> unit
|
||||
type ev_on_new_term = t -> N.t -> term -> th_data -> th_data option
|
||||
|
||||
val create :
|
||||
?stat:Stat.t ->
|
||||
?th:Theory.t list ->
|
||||
?on_merge:(t -> N.t -> N.t -> Expl.t -> unit) list ->
|
||||
?on_merge:ev_on_merge list ->
|
||||
?on_new_term:ev_on_new_term list ->
|
||||
?size:[`Small | `Big] ->
|
||||
term_state ->
|
||||
t
|
||||
(** Create a new congruence closure. *)
|
||||
|
||||
val add_th : t -> Theory.t -> unit
|
||||
(** Add a (micro) theory to the congruence closure.
|
||||
@raise Error.Error if there is already a theory with
|
||||
the same key. *)
|
||||
|
||||
val on_merge : t -> (t -> N.t -> N.t -> Expl.t -> unit) -> unit
|
||||
val on_merge : t -> ev_on_merge -> unit
|
||||
(** Add a function to be called when two classes are merged *)
|
||||
|
||||
val on_new_term : t -> ev_on_new_term -> unit
|
||||
(** Add a function to be called when a new node is created *)
|
||||
|
||||
val set_as_lit : t -> N.t -> lit -> unit
|
||||
(** map the given node to a literal. *)
|
||||
|
||||
|
|
@ -217,5 +298,4 @@ module type S = sig
|
|||
val check_invariants : t -> unit
|
||||
val pp_full : t Fmt.printer
|
||||
(**/**)
|
||||
|
||||
end
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
|
||||
open Congruence_closure_intf
|
||||
|
||||
type res =
|
||||
| Sat
|
||||
| Unsat
|
||||
|
||||
module type TERM = CC_types.TERM
|
||||
module type TERM = Congruence_closure_intf.TERM
|
||||
|
||||
module type S = sig
|
||||
type term
|
||||
|
|
@ -18,12 +20,12 @@ module type S = sig
|
|||
val distinct : t -> term list -> unit
|
||||
|
||||
val check : t -> res
|
||||
|
||||
val classes : t -> term Iter.t Iter.t
|
||||
end
|
||||
|
||||
|
||||
module Make(A: TERM) = struct
|
||||
open CC_types
|
||||
|
||||
module Fun = A.Fun
|
||||
module T = A.Term
|
||||
type fun_ = A.Fun.t
|
||||
|
|
@ -47,6 +49,8 @@ module Make(A: TERM) = struct
|
|||
let[@inline] equal (n1:t) n2 = T.equal n1.n_t n2.n_t
|
||||
let[@inline] hash (n:t) = T.hash n.n_t
|
||||
let[@inline] size (n:t) = n.n_size
|
||||
let[@inline] is_root n = n == n.n_root
|
||||
let[@inline] term n = n.n_t
|
||||
let pp out n = T.pp out n.n_t
|
||||
|
||||
let add_parent (self:t) ~p : unit =
|
||||
|
|
@ -171,7 +175,7 @@ module Make(A: TERM) = struct
|
|||
(* find representative *)
|
||||
let[@inline] find_ (n:node) : node =
|
||||
let r = n.n_root in
|
||||
assert (Node.equal r.n_root r);
|
||||
assert (Node.is_root r);
|
||||
r
|
||||
|
||||
let find_t_ (self:t) (t:term): node =
|
||||
|
|
@ -313,4 +317,10 @@ module Make(A: TERM) = struct
|
|||
self.ok <- false;
|
||||
Unsat
|
||||
|
||||
let classes self : _ Iter.t =
|
||||
T_tbl.values self.tbl
|
||||
|> Iter.filter Node.is_root
|
||||
|> Iter.map
|
||||
(fun n -> Node.iter_cls n |> Iter.map Node.term)
|
||||
|
||||
end
|
||||
|
|
|
|||
|
|
@ -6,11 +6,13 @@
|
|||
It just decides the satisfiability of a set of (dis)equations.
|
||||
*)
|
||||
|
||||
open Congruence_closure_intf
|
||||
|
||||
type res =
|
||||
| Sat
|
||||
| Unsat
|
||||
|
||||
module type TERM = CC_types.TERM
|
||||
module type TERM = Congruence_closure_intf.TERM
|
||||
|
||||
module type S = sig
|
||||
type term
|
||||
|
|
@ -28,6 +30,10 @@ module type S = sig
|
|||
(** [distinct cc l] asserts that all terms in [l] are distinct *)
|
||||
|
||||
val check : t -> res
|
||||
|
||||
val classes : t -> term Iter.t Iter.t
|
||||
(** Traverse the set of classes in the congruence closure.
|
||||
This should be called only if {!check} returned [Sat]. *)
|
||||
end
|
||||
|
||||
module Make(A: TERM)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
|
||||
type ('f, 't, 'ts) view = ('f, 't, 'ts) CC_types.view =
|
||||
type ('f, 't, 'ts) view = ('f, 't, 'ts) Congruence_closure_intf.view =
|
||||
| Bool of bool
|
||||
| App_fun of 'f * 'ts
|
||||
| App_ho of 't * 'ts
|
||||
|
|
@ -8,16 +8,12 @@ type ('f, 't, 'ts) view = ('f, 't, 'ts) CC_types.view =
|
|||
| Not of 't
|
||||
| Opaque of 't (* do not enter *)
|
||||
|
||||
module CC_types = CC_types
|
||||
|
||||
(** Parameter for the congruence closure *)
|
||||
module type TERM_LIT = CC_types.TERM_LIT
|
||||
module type FULL = CC_types.FULL
|
||||
module type TERM_LIT = Congruence_closure_intf.TERM_LIT
|
||||
module type ARG = Congruence_closure_intf.ARG
|
||||
module type S = Congruence_closure.S
|
||||
|
||||
module Mini_cc = Mini_cc
|
||||
module Congruence_closure = Congruence_closure
|
||||
module Key = Congruence_closure.Key
|
||||
|
||||
module Make = Congruence_closure.Make
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue