From ddde590ffdeb3e612b0e6760152434d1c29ddc6e Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Tue, 2 Apr 2019 21:30:14 -0500 Subject: [PATCH] refactor(cc): no micro theories, only callbacks --- src/cc/CC_types.ml | 115 ------------ src/cc/Congruence_closure.ml | 285 ++++++------------------------ src/cc/Congruence_closure.mli | 5 +- src/cc/Congruence_closure_intf.ml | 188 ++++++++++++++------ src/cc/Mini_cc.ml | 18 +- src/cc/Mini_cc.mli | 8 +- src/cc/Sidekick_cc.ml | 10 +- 7 files changed, 217 insertions(+), 412 deletions(-) delete mode 100644 src/cc/CC_types.ml diff --git a/src/cc/CC_types.ml b/src/cc/CC_types.ml deleted file mode 100644 index 2b75cd7d..00000000 --- a/src/cc/CC_types.ml +++ /dev/null @@ -1,115 +0,0 @@ - -(** {1 Types used by the congruence closure} *) - -type ('f, 't, 'ts) view = - | Bool of bool - | App_fun of 'f * 'ts - | App_ho of 't * 'ts - | If of 't * 't * 't - | Eq of 't * 't - | Not of 't - | Opaque of 't (* do not enter *) - -let[@inline] map_view ~f_f ~f_t ~f_ts (v:_ view) : _ view = - match v with - | Bool b -> Bool b - | App_fun (f, args) -> App_fun (f_f f, f_ts args) - | App_ho (f, args) -> App_ho (f_t f, f_ts args) - | Not t -> Not (f_t t) - | If (a,b,c) -> If (f_t a, f_t b, f_t c) - | Eq (a,b) -> Eq (f_t a, f_t b) - | Opaque t -> Opaque (f_t t) - -let iter_view ~f_f ~f_t ~f_ts (v:_ view) : unit = - match v with - | Bool _ -> () - | App_fun (f, args) -> f_f f; f_ts args - | App_ho (f, args) -> f_t f; f_ts args - | Not t -> f_t t - | If (a,b,c) -> f_t a; f_t b; f_t c; - | Eq (a,b) -> f_t a; f_t b - | Opaque t -> f_t t - -module type TERM = sig - module Fun : sig - type t - val equal : t -> t -> bool - val hash : t -> int - val pp : t Fmt.printer - end - - module Term : sig - type t - val equal : t -> t -> bool - val hash : t -> int - val pp : t Fmt.printer - - type state - - val bool : state -> bool -> t - - (** View the term through the lens of the congruence closure *) - val cc_view : t -> (Fun.t, t, t Iter.t) view - end -end - -module type TERM_LIT = sig - include TERM - - module Lit : sig - type t - val neg : t -> t - val equal : t -> t -> bool - val hash : t -> int - val pp : t Fmt.printer - - val sign : t -> bool - val term : t -> Term.t - end -end - -module type FULL = sig - include TERM_LIT - - module Proof : sig - type t - val pp : t Fmt.printer - - val default : t - (* TODO: to give more details - val cc_lemma : unit -> t - *) - end - - module Ty : sig - type t - - val equal : t -> t -> bool - val hash : t -> int - val pp : t Fmt.printer - end - - module Value : sig - type t - - val pp : t Fmt.printer - - val fresh : Term.t -> t - - val true_ : t - val false_ : t - end - - module Model : sig - type t - - val pp : t Fmt.printer - - val eval : t -> Term.t -> Value.t option - (** Evaluate the term in the current model *) - - val add : Term.t -> Value.t -> t -> t - end -end - -(* TODO: micro theory *) diff --git a/src/cc/Congruence_closure.ml b/src/cc/Congruence_closure.ml index c7993033..99b5f0ff 100644 --- a/src/cc/Congruence_closure.ml +++ b/src/cc/Congruence_closure.ml @@ -1,66 +1,9 @@ -open CC_types +open Congruence_closure_intf module type ARG = Congruence_closure_intf.ARG module type S = Congruence_closure_intf.S -module type THEORY_KEY = Congruence_closure_intf.THEORY_KEY - -(** Custom keys for theory data. - This imitates the classic tricks for heterogeneous maps - https://blog.janestreet.com/a-universal-type/ - - It needs to form a commutative monoid where values are persistent so - they can be restored during backtracking. -*) -module Key = struct - module type KEY_IMPL = sig - type term - type lit - type t - val id : int - val name : string - val pp : t Fmt.printer - val equal : t -> t -> bool - val merge : t -> t -> t - exception Store of t - end - - type ('term,'lit,'a) t = - (module KEY_IMPL with type term = 'term and type lit = 'lit and type t = 'a) - - let n_ = ref 0 - - let create (type term)(type lit)(type d) - ?(pp=fun out _ -> Fmt.string out "") - ~name ~eq ~merge () : (term,lit,d) t = - let module K = struct - type nonrec term = term - type nonrec lit = lit - type t = d - let id = !n_ - let name = name - let pp = pp - let merge = merge - let equal = eq - exception Store of d - end in - incr n_; - (module K) - - let[@inline] id - : type term lit a. (term,lit,a) t -> int - = fun (module K) -> K.id - - let[@inline] equal - : type term lit a b. (term,lit,a) t -> (term,lit,b) t -> bool - = fun (module K1) (module K2) -> K1.id = K2.id - - let pp - : type term lit a. (term,lit,a) t Fmt.printer - = fun out (module K) -> Fmt.string out K.name -end - module Bits = CCBitField.Make() let field_is_pending = Bits.mk_field() @@ -81,6 +24,7 @@ module Make(A: ARG) = struct type fun_ = A.Fun.t type proof = A.Proof.t type model = A.Model.t + type th_data = A.Data.t (** Actions available to the theory *) type sat_actions = (Msat.void, lit, Msat.void, proof) Msat.acts @@ -88,60 +32,6 @@ module Make(A: ARG) = struct module T = A.Term module Fun = A.Fun module Key = Key - module IM = Map.Make(CCInt) - - (** Map for theory data associated with representatives *) - module K_map = struct - type 'a key = (term,lit,'a) Key.t - type pair = Pair : 'a key * exn -> pair - - type t = pair IM.t - - let empty = IM.empty - - let[@inline] mem k t = IM.mem (Key.id k) t - - let find (type a) (k : a key) (self:t) : a option = - let (module K) = k in - match IM.find K.id self with - | Pair (_, K.Store v) -> Some v - | _ -> None - | exception Not_found -> None - - let add (type a) (k : a key) (v:a) (self:t) : t = - let (module K) = k in - IM.add K.id (Pair (k, K.Store v)) self - - let remove (type a) (k: a key) self : t = - let (module K) = k in - IM.remove K.id self - - let equal (m1:t) (m2:t) : bool = - IM.equal - (fun p1 p2 -> - let Pair ((module K1), v1) = p1 in - let Pair ((module K2), v2) = p2 in - assert (K1.id = K2.id); - match v1, v2 with K1.Store v1, K1.Store v2 -> K1.equal v1 v2 | _ -> false) - m1 m2 - - let merge ~f_both (m1:t) (m2:t) : t = - IM.merge - (fun _ p1 p2 -> - match p1, p2 with - | None, None -> None - | Some v, None - | None, Some v -> Some v - | Some (Pair ((module K1) as key1, pair1)), Some (Pair (_, pair2)) -> - match pair1, pair2 with - | K1.Store v1, K1.Store v2 -> - f_both K1.id pair1 pair2; (* callback for checking compat *) - let v12 = K1.merge v1 v2 in (* merge content *) - Some (Pair (key1, K1.Store v12)) - | _ -> assert false - ) - m1 m2 - end (** A node of the congruence closure. An equivalence class is represented by its "root" element, @@ -156,7 +46,7 @@ module Make(A: ARG) = struct mutable n_size: int; (* size of the class *) mutable n_as_lit: lit option; (* TODO: put into payload? and only in root? *) mutable n_expl: explanation_forest_link; (* the rooted forest for explanations *) - mutable n_th_data: K_map.t; (* theory data *) + mutable n_th_data: th_data; (* theory data *) } and signature = (fun_, node, node list) view @@ -173,8 +63,9 @@ module Make(A: ARG) = struct | E_reduction (* by pure reduction, tautologically equal *) | E_lit of lit (* because of this literal *) | E_merge of node * node - | E_list of explanation list + | E_merge_t of term * term | E_congruence of node * node (* caused by normal congruence *) + | E_and of explanation * explanation type repr = node type conflict = lit list @@ -182,11 +73,12 @@ module Make(A: ARG) = struct module N = struct type t = node - let[@inline] equal (n1:t) n2 = T.equal n1.n_term n2.n_term + let[@inline] equal (n1:t) n2 = n1 == n2 let[@inline] hash n = T.hash n.n_term let[@inline] term n = n.n_term let[@inline] pp out n = T.pp out n.n_term let[@inline] as_lit n = n.n_as_lit + let[@inline] th_data n = n.n_th_data let make (t:term) : t = let rec n = { @@ -199,7 +91,7 @@ module Make(A: ARG) = struct n_expl=FL_none; n_next=n; n_size=1; - n_th_data=K_map.empty; + n_th_data=A.Data.empty; } in n @@ -214,7 +106,7 @@ module Make(A: ARG) = struct in aux n - let iter_class n = + let[@inline] iter_class n = assert (is_root n); iter_class_ n @@ -224,6 +116,11 @@ module Make(A: ARG) = struct let[@inline] get_field f t = Bits.get f t.n_bits let[@inline] set_field f b t = t.n_bits <- Bits.set f b t.n_bits + + let[@inline] get_field_usr1 t = get_field field_usr1 t + let[@inline] set_field_usr1 t b = set_field field_usr1 b t + let[@inline] get_field_usr2 t = get_field field_usr2 t + let[@inline] set_field_usr2 t b = set_field field_usr2 b t end module N_tbl = CCHashtbl.Make(N) @@ -236,19 +133,25 @@ module Make(A: ARG) = struct | E_lit lit -> A.Lit.pp out lit | E_congruence (n1,n2) -> Fmt.fprintf out "(@[congruence@ %a@ %a@])" N.pp n1 N.pp n2 | E_merge (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" N.pp a N.pp b - | E_list l -> - Format.fprintf out "(@[and@ %a@])" - Fmt.(list ~sep:(return "@ ") @@ within "[" "]" @@ hvbox @@ pp) l + | E_merge_t (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" T.pp a T.pp b + | E_and (a,b) -> + Format.fprintf out "(@[and@ %a@ %a@])" pp a pp b let mk_reduction : t = E_reduction let[@inline] mk_congruence n1 n2 : t = E_congruence (n1,n2) - let[@inline] mk_merge a b : t = E_merge (a,b) + let[@inline] mk_merge a b : t = if N.equal a b then mk_reduction else E_merge (a,b) + let[@inline] mk_merge_t a b : t = if T.equal a b then mk_reduction else E_merge_t (a,b) let[@inline] mk_lit l : t = E_lit l - let mk_list l = + + let rec mk_list l = match l with | [] -> mk_reduction | [x] -> x - | l -> E_list l + | E_reduction :: tl -> mk_list tl + | x :: y -> + match mk_list y with + | E_reduction -> x + | y' -> E_and (x,y') end (** A signature is a shallow term shape where immediate subterms @@ -302,15 +205,6 @@ module Make(A: ARG) = struct type combine_task = | CT_merge of node * node * explanation - module type THEORY = sig - type cc - type data - val key_id : int - val key : (term,lit,data) Key.t - val on_merge : cc -> N.t -> data -> N.t -> data -> Expl.t -> unit - val on_new_term: cc -> N.t -> term -> data option - end - type t = { tst: term_state; tbl: node T_tbl.t; @@ -326,8 +220,8 @@ module Make(A: ARG) = struct pending: node Vec.t; combine: combine_task Vec.t; undo: (unit -> unit) Backtrack_stack.t; - mutable theories: theory IM.t; - mutable on_merge: (t -> N.t -> N.t -> Expl.t -> unit) list; + mutable on_merge: ev_on_merge list; + mutable on_new_term: ev_on_new_term list; mutable ps_lits: lit list; (* TODO: thread it around instead? *) (* proof state *) ps_queue: (node*node) Vec.t; @@ -344,9 +238,8 @@ module Make(A: ARG) = struct several times. See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *) - and theory = (module THEORY with type cc = t) - - type cc = t + and ev_on_merge = t -> N.t -> th_data -> N.t -> th_data -> Expl.t -> unit + and ev_on_new_term = t -> N.t -> term -> th_data -> th_data option let[@inline] size_ (r:repr) = r.n_size let[@inline] true_ cc = Lazy.force cc.true_ @@ -359,10 +252,6 @@ module Make(A: ARG) = struct (* check if [t] is in the congruence closure. Invariant: [in_cc t ∧ do_cc t => forall u subterm t, in_cc u] *) let[@inline] mem (cc:t) (t:term): bool = T_tbl.mem cc.tbl t - (* FIXME - let on_merge cc f = cc.on_merge <- f :: cc.on_merge - let on_new_term cc f = cc.on_new_term <- f :: cc.on_new_term - *) (* find representative, recursively *) let[@unroll 2] rec find_rec (n:node) : repr = @@ -408,7 +297,7 @@ module Make(A: ARG) = struct (* compute up-to-date signature *) let update_sig (s:signature) : Signature.t = - CC_types.map_view s + Congruence_closure_intf.map_view s ~f_f:(fun x->x) ~f_t:find_ ~f_ts:(List.map find_) @@ -475,7 +364,7 @@ module Make(A: ARG) = struct | FL_none -> 0 | FL_some {next=t'; _} -> 1 + distance_to_root t' - (* TODO: bool flag on nodes + stepwise progress + cleanup *) + (* TODO: new bool flag on nodes + stepwise progress + cleanup *) (* find the closest common ancestor of [a] and [b] in the proof forest *) let find_common_ancestor (a:node) (b:node) : node = let d_a = distance_to_root a in @@ -505,13 +394,11 @@ module Make(A: ARG) = struct let[@inline] ps_add_obligation (cc:t) a b = Vec.push cc.ps_queue (a,b) let[@inline] ps_add_lit ps l = ps.ps_lits <- l :: ps.ps_lits - (* TODO: remove *) let ps_clear (cc:t) = cc.ps_lits <- []; Vec.clear cc.ps_queue; () - (* TODO: turn this into a fold? *) (* decompose explanation [e] of why [n1 = n2] *) let rec decompose_explain cc (e:explanation) : unit = Log.debugf 5 (fun k->k "(@[cc.decompose_expl@ %a@])" Expl.pp e); @@ -536,7 +423,14 @@ module Make(A: ARG) = struct end | E_lit lit -> ps_add_lit cc lit | E_merge (a,b) -> ps_add_obligation cc a b - | E_list l -> List.iter (decompose_explain cc) l + | E_merge_t (a,b) -> + (* find nodes for [a] and [b] on the fly *) + begin match T_tbl.find cc.tbl a, T_tbl.find cc.tbl b with + | a, b -> ps_add_obligation cc a b + | exception Not_found -> + Error.errorf "expl: cannot find node(s) for %a, %a" T.pp a T.pp b + end + | E_and (a,b) -> decompose_explain cc a; decompose_explain cc b (* explain why [a = parent_a], where [a -> ... -> parent_a] in the proof forest *) @@ -565,7 +459,6 @@ module Make(A: ARG) = struct done; cc.ps_lits - (* TODO: do not use ps_lits anymore *) let explain_eq_n ?(init=[]) cc (n1:node) (n2:node) : lit list = ps_clear cc; cc.ps_lits <- init; @@ -604,15 +497,15 @@ module Make(A: ARG) = struct push_pending cc n; ); (* initial theory data *) - let th_map = - IM.fold - (fun _ (module Th: THEORY with type cc=cc) th_map -> - match Th.on_new_term cc n t with - | None -> th_map - | Some v -> K_map.add Th.key v th_map) - cc.theories K_map.empty + let th_data = + List.fold_left + (fun data f -> + match f cc n t data with + | None -> data + | Some d -> d) + A.Data.empty cc.on_new_term in - n.n_th_data <- th_map; + n.n_th_data <- th_data; n (* compute the initial signature of the given node *) @@ -754,36 +647,19 @@ module Make(A: ARG) = struct merge_bool rb b ra a; (* perform [union r_from r_into] *) Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into); - (* call [on_merge] functions *) - List.iter (fun f -> f cc r_into r_from e_ab) cc.on_merge; - (* call micro theories *) + (* call [on_merge] functions, and merge theory data items *) begin let th_into = r_into.n_th_data in let th_from = r_from.n_th_data in - (* merge the two maps; if a key occurs in both, looks for theories with - this particular key *) - let th = - K_map.merge th_into th_from - ~f_both:(fun id pair_into pair_from -> - match IM.find id cc.theories with - | (module Th : THEORY with type cc=t) -> - (* casting magic *) - let (module K) = Th.key in - begin match pair_into, pair_from with - | K.Store v_into, K.Store v_from -> - Log.debugf 15 - (fun k->k "(@[cc.merge.th-on-merge@ :th %s@])" K.name); - (* FIXME: explanation is a=ra, e_ab, b=rb *) - Th.on_merge cc r_into v_into r_from v_from e_ab - | _ -> assert false - end - | exception Not_found -> ()) - in + let new_data = A.Data.merge th_into th_from in (* restore old data, if it changed *) - if not @@ K_map.equal th th_into then ( + if new_data != th_into then ( on_backtrack cc (fun () -> r_into.n_th_data <- th_into); ); - r_into.n_th_data <- th; + r_into.n_th_data <- new_data; + (* explanation is [a=ra & e_ab & b=rb] *) + let expl = Expl.mk_list [e_ab; Expl.mk_merge a ra; Expl.mk_merge b rb] in + List.iter (fun f -> f cc r_into th_into r_from th_from expl) cc.on_merge; end; begin (* parents might have a different signature, check for collisions *) @@ -864,8 +740,6 @@ module Make(A: ARG) = struct module Theory = struct type cc = t - type t = theory - type 'a key = (term,lit,'a) Key.t (* raise a conflict *) let raise_conflict cc expl = @@ -879,41 +753,6 @@ module Make(A: ARG) = struct merge_classes cc n1 n2 expl let add_term = add_term - - let get_data _cc n key = - assert (N.is_root n); - K_map.find key n.n_th_data - - (* FIXME: call micro theory here? in case of merge *) - (* update data for [n] *) - let add_data (type a) (self:cc) (n:node) (key: a key) (v:a) : unit = - let n = find_ n in - let map = n.n_th_data in - let old_v = K_map.find key map in - let v', is_diff = match old_v with - | None -> v, true - | Some old_v -> - let (module K) = key in - let v' = K.merge old_v v in - v', K.equal v v' - in - if is_diff then ( - on_backtrack self (fun () -> n.n_th_data <- map); - ); - n.n_th_data <- K_map.add key v' map; - () - - let make (type a) ~(key:a key) ~on_merge ~on_new_term () : t = - let module Th = struct - type nonrec cc = cc - type data = a - let key = key - let key_id = Key.id key - let on_merge = on_merge - let on_new_term = on_new_term - end in - (module Th : THEORY with type cc=cc) - end let check_invariants_ (cc:t) = @@ -1000,25 +839,18 @@ module Make(A: ARG) = struct let n2 = add_term cc t2 in merge_classes cc n1 n2 expl - let add_th (self:t) (th:theory) : unit = - let (module Th) = th in - if IM.mem Th.key_id self.theories then ( - Error.errorf "attempt to add two theories with key %a" Key.pp Th.key - ); - Log.debugf 3 (fun k->k "(@[@{cc.add-theory@} %a@])" Key.pp Th.key); - self.theories <- IM.add Th.key_id th self.theories - let on_merge cc f = cc.on_merge <- f :: cc.on_merge + let on_new_term cc f = cc.on_new_term <- f :: cc.on_new_term let create ?(stat=Stat.global) - ?th:(theories=[]) ?(on_merge=[]) ?(size=`Big) (tst:term_state) : t = + ?(on_merge=[]) ?(on_new_term=[]) ?(size=`Big) (tst:term_state) : t = let size = match size with `Small -> 128 | `Big -> 2048 in let rec cc = { tst; tbl = T_tbl.create size; signatures_tbl = Sig_tbl.create size; - theories=IM.empty; on_merge; + on_new_term; pending=Vec.create(); combine=Vec.create(); ps_lits=[]; @@ -1037,7 +869,6 @@ module Make(A: ARG) = struct in ignore (Lazy.force true_ : node); ignore (Lazy.force false_ : node); - List.iter (add_th cc) theories; (* now add theories *) cc let[@inline] find_t cc t : repr = diff --git a/src/cc/Congruence_closure.mli b/src/cc/Congruence_closure.mli index cc26fa17..717e327e 100644 --- a/src/cc/Congruence_closure.mli +++ b/src/cc/Congruence_closure.mli @@ -3,9 +3,6 @@ module type ARG = Congruence_closure_intf.ARG module type S = Congruence_closure_intf.S -module type THEORY_KEY = Congruence_closure_intf.THEORY_KEY -module Key : THEORY_KEY - module Make(A: ARG) : S with type term = A.Term.t and type lit = A.Lit.t @@ -13,4 +10,4 @@ module Make(A: ARG) and type term_state = A.Term.state and type proof = A.Proof.t and type model = A.Model.t - and module Key = Key + and type th_data = A.Data.t diff --git a/src/cc/Congruence_closure_intf.ml b/src/cc/Congruence_closure_intf.ml index 02aa2418..25e83f13 100644 --- a/src/cc/Congruence_closure_intf.ml +++ b/src/cc/Congruence_closure_intf.ml @@ -1,35 +1,124 @@ -module type ARG = CC_types.FULL +(** {1 Types used by the congruence closure} *) -module type THEORY_KEY = sig - type ('term,'lit,'a) t - (** An access key for theories which have per-class data ['a] *) +type ('f, 't, 'ts) view = + | Bool of bool + | App_fun of 'f * 'ts + | App_ho of 't * 'ts + | If of 't * 't * 't + | Eq of 't * 't + | Not of 't + | Opaque of 't (* do not enter *) - val create : - ?pp:'a Fmt.printer -> - name:string -> - eq:('a -> 'a -> bool) -> - merge:('a -> 'a -> 'a) -> - unit -> - ('term,'lit,'a) t - (** Generative creation of keys for the given theory data. +let[@inline] map_view ~f_f ~f_t ~f_ts (v:_ view) : _ view = + match v with + | Bool b -> Bool b + | App_fun (f, args) -> App_fun (f_f f, f_ts args) + | App_ho (f, args) -> App_ho (f_t f, f_ts args) + | Not t -> Not (f_t t) + | If (a,b,c) -> If (f_t a, f_t b, f_t c) + | Eq (a,b) -> Eq (f_t a, f_t b) + | Opaque t -> Opaque (f_t t) - @param eq : Equality. This is used to optimize backtracking info. +let iter_view ~f_f ~f_t ~f_ts (v:_ view) : unit = + match v with + | Bool _ -> () + | App_fun (f, args) -> f_f f; f_ts args + | App_ho (f, args) -> f_t f; f_ts args + | Not t -> f_t t + | If (a,b,c) -> f_t a; f_t b; f_t c; + | Eq (a,b) -> f_t a; f_t b + | Opaque t -> f_t t - @param merge : - [merge d1 d2] is called when merging classes with data [d1] and [d2] - respectively. The theory should already have checked that the merge - is compatible, and this produces the combined data for terms in the - merged class. - @param name name of the theory which owns this data - @param pp a printer for the data - *) +module type TERM = sig + module Fun : sig + type t + val equal : t -> t -> bool + val hash : t -> int + val pp : t Fmt.printer + end - val equal : ('t,'lit,_) t -> ('t,'lit,_) t -> bool - (** Checks if two keys are equal (generatively) *) + module Term : sig + type t + val equal : t -> t -> bool + val hash : t -> int + val pp : t Fmt.printer - val pp : _ t Fmt.printer - (** Prints the name of the key. *) + type state + + val bool : state -> bool -> t + + (** View the term through the lens of the congruence closure *) + val cc_view : t -> (Fun.t, t, t Iter.t) view + end +end + +module type TERM_LIT = sig + include TERM + + module Lit : sig + type t + val neg : t -> t + val equal : t -> t -> bool + val hash : t -> int + val pp : t Fmt.printer + + val sign : t -> bool + val term : t -> Term.t + end +end + +module type ARG = sig + include TERM_LIT + + module Proof : sig + type t + val pp : t Fmt.printer + + val default : t + (* TODO: to give more details + val cc_lemma : unit -> t + *) + end + + module Ty : sig + type t + + val equal : t -> t -> bool + val hash : t -> int + val pp : t Fmt.printer + end + + module Value : sig + type t + + val pp : t Fmt.printer + + val fresh : Term.t -> t + + val true_ : t + val false_ : t + end + + module Model : sig + type t + + val pp : t Fmt.printer + + val eval : t -> Term.t -> Value.t option + (** Evaluate the term in the current model *) + + val add : Term.t -> Value.t -> t -> t + end + + (** Monoid embedded in every node *) + module Data : sig + type t + + val empty : t + + val merge : t -> t -> t + end end module type S = sig @@ -39,9 +128,7 @@ module type S = sig type lit type proof type model - - (** Implementation of theory keys *) - module Key : THEORY_KEY + type th_data type t (** Global state of the congruence closure *) @@ -80,6 +167,15 @@ module type S = sig val iter_parents : t -> t Iter.t (** Traverse the parents of the class. Invariant: [is_root n] (see {!find} below) *) + + val th_data : t -> th_data + (** Access theory data for this node *) + + val get_field_usr1 : t -> bool + val set_field_usr1 : t -> bool -> unit + + val get_field_usr2 : t -> bool + val set_field_usr2 : t -> bool -> unit end module Expl : sig @@ -87,6 +183,7 @@ module type S = sig val pp : t Fmt.printer val mk_merge : N.t -> N.t -> t + val mk_merge_t : term -> term -> t val mk_lit : lit -> t val mk_list : t list -> t end @@ -117,9 +214,6 @@ module type S = sig module Theory : sig type cc = t - type t - - type 'a key = (term,lit,'a) Key.t val raise_conflict : cc -> Expl.t -> unit (** Raise a conflict with the given explanation @@ -134,39 +228,26 @@ module type S = sig val add_term : cc -> term -> N.t (** Add/retrieve node for this term. To be used in theories *) - - val get_data : cc -> N.t -> 'a key -> 'a option - (** Get data information for this particular representative *) - - val add_data : cc -> N.t -> 'a key -> 'a -> unit - (** Add data to this particular representative. Will be backtracked. *) - - val make : - key:'a key -> - on_merge:(cc -> N.t -> 'a -> N.t -> 'a -> Expl.t -> unit) -> - on_new_term:(cc -> N.t -> term -> 'a option) -> - unit -> - t - (** Build a micro theory. It can use the callbacks above. *) end + type ev_on_merge = t -> N.t -> th_data -> N.t -> th_data -> Expl.t -> unit + type ev_on_new_term = t -> N.t -> term -> th_data -> th_data option + val create : ?stat:Stat.t -> - ?th:Theory.t list -> - ?on_merge:(t -> N.t -> N.t -> Expl.t -> unit) list -> + ?on_merge:ev_on_merge list -> + ?on_new_term:ev_on_new_term list -> ?size:[`Small | `Big] -> term_state -> t (** Create a new congruence closure. *) - val add_th : t -> Theory.t -> unit - (** Add a (micro) theory to the congruence closure. - @raise Error.Error if there is already a theory with - the same key. *) - - val on_merge : t -> (t -> N.t -> N.t -> Expl.t -> unit) -> unit + val on_merge : t -> ev_on_merge -> unit (** Add a function to be called when two classes are merged *) + val on_new_term : t -> ev_on_new_term -> unit + (** Add a function to be called when a new node is created *) + val set_as_lit : t -> N.t -> lit -> unit (** map the given node to a literal. *) @@ -217,5 +298,4 @@ module type S = sig val check_invariants : t -> unit val pp_full : t Fmt.printer (**/**) - end diff --git a/src/cc/Mini_cc.ml b/src/cc/Mini_cc.ml index f968ba07..9b5a37dd 100644 --- a/src/cc/Mini_cc.ml +++ b/src/cc/Mini_cc.ml @@ -1,9 +1,11 @@ +open Congruence_closure_intf + type res = | Sat | Unsat -module type TERM = CC_types.TERM +module type TERM = Congruence_closure_intf.TERM module type S = sig type term @@ -18,12 +20,12 @@ module type S = sig val distinct : t -> term list -> unit val check : t -> res + + val classes : t -> term Iter.t Iter.t end module Make(A: TERM) = struct - open CC_types - module Fun = A.Fun module T = A.Term type fun_ = A.Fun.t @@ -47,6 +49,8 @@ module Make(A: TERM) = struct let[@inline] equal (n1:t) n2 = T.equal n1.n_t n2.n_t let[@inline] hash (n:t) = T.hash n.n_t let[@inline] size (n:t) = n.n_size + let[@inline] is_root n = n == n.n_root + let[@inline] term n = n.n_t let pp out n = T.pp out n.n_t let add_parent (self:t) ~p : unit = @@ -171,7 +175,7 @@ module Make(A: TERM) = struct (* find representative *) let[@inline] find_ (n:node) : node = let r = n.n_root in - assert (Node.equal r.n_root r); + assert (Node.is_root r); r let find_t_ (self:t) (t:term): node = @@ -313,4 +317,10 @@ module Make(A: TERM) = struct self.ok <- false; Unsat + let classes self : _ Iter.t = + T_tbl.values self.tbl + |> Iter.filter Node.is_root + |> Iter.map + (fun n -> Node.iter_cls n |> Iter.map Node.term) + end diff --git a/src/cc/Mini_cc.mli b/src/cc/Mini_cc.mli index b460c74a..6f96c723 100644 --- a/src/cc/Mini_cc.mli +++ b/src/cc/Mini_cc.mli @@ -6,11 +6,13 @@ It just decides the satisfiability of a set of (dis)equations. *) +open Congruence_closure_intf + type res = | Sat | Unsat -module type TERM = CC_types.TERM +module type TERM = Congruence_closure_intf.TERM module type S = sig type term @@ -28,6 +30,10 @@ module type S = sig (** [distinct cc l] asserts that all terms in [l] are distinct *) val check : t -> res + + val classes : t -> term Iter.t Iter.t + (** Traverse the set of classes in the congruence closure. + This should be called only if {!check} returned [Sat]. *) end module Make(A: TERM) diff --git a/src/cc/Sidekick_cc.ml b/src/cc/Sidekick_cc.ml index 574fe5bd..f0bee810 100644 --- a/src/cc/Sidekick_cc.ml +++ b/src/cc/Sidekick_cc.ml @@ -1,5 +1,5 @@ -type ('f, 't, 'ts) view = ('f, 't, 'ts) CC_types.view = +type ('f, 't, 'ts) view = ('f, 't, 'ts) Congruence_closure_intf.view = | Bool of bool | App_fun of 'f * 'ts | App_ho of 't * 'ts @@ -8,16 +8,12 @@ type ('f, 't, 'ts) view = ('f, 't, 'ts) CC_types.view = | Not of 't | Opaque of 't (* do not enter *) -module CC_types = CC_types - (** Parameter for the congruence closure *) -module type TERM_LIT = CC_types.TERM_LIT -module type FULL = CC_types.FULL +module type TERM_LIT = Congruence_closure_intf.TERM_LIT +module type ARG = Congruence_closure_intf.ARG module type S = Congruence_closure.S module Mini_cc = Mini_cc module Congruence_closure = Congruence_closure -module Key = Congruence_closure.Key - module Make = Congruence_closure.Make