diff --git a/src/cc/Congruence_closure.ml b/src/cc/Congruence_closure.ml index a7656a16..5af965fd 100644 --- a/src/cc/Congruence_closure.ml +++ b/src/cc/Congruence_closure.ml @@ -3,6 +3,42 @@ open CC_types module type ARG = Congruence_closure_intf.ARG module type S = Congruence_closure_intf.S +module type THEORY_DATA = Congruence_closure_intf.THEORY_DATA +module type THEORY_KEY = Congruence_closure_intf.THEORY_KEY + +type ('t, 'a) theory_data = ('t,'a) Congruence_closure_intf.theory_data + +module type KEY_IMPL = sig + include THEORY_DATA + exception Store of t + val id : int +end + +(** Custom keys for theory data. + This imitates the classic tricks for heterogeneous maps + https://blog.janestreet.com/a-universal-type/ + *) +module Key = struct + type ('term, 'a) t = (module KEY_IMPL with type term = 'term and type t = 'a) + + let n_ = ref 0 + + let create (type term)(type d) (th:(term,d) theory_data) : (term,d) t = + let (module TH) = th in + let module K = struct + include TH + exception Store of d + let id = !n_ + end in + incr n_; + (module K) + + let id (module K : KEY_IMPL) = K.id + + let equal + : type a b term. (term,a) t -> (term,b) t -> bool + = fun (module K1) (module K2) -> K1.id = K2.id +end module Bits = CCBitField.Make() @@ -30,6 +66,62 @@ module Make(A: ARG) = struct module T = A.Term module Fun = A.Fun + module Key = Key + + + (** Map for theory data associated with representatives *) + module K_map = struct + type pair = Pair : (term, 'a) Key.t * exn -> pair + module IM = Map.Make(CCInt) + + type t = pair IM.t + + let empty = IM.empty + + let[@inline] mem k t = IM.mem (Key.id k) t + + let is_empty = IM.is_empty + + let find (type a) (k : (term,a) Key.t) (self:t) : a option = + let (module K) = k in + match IM.find K.id self with + | Pair (_, K.Store v) -> Some v + | _ -> None + | exception Not_found -> None + + let add (type a) (k : (term,a) Key.t) (v:a) (self:t) : t = + let (module K) = k in + IM.add K.id (Pair (k, K.Store v)) self + + let remove (type a) (k: (term,a) Key.t) self : t = + let (module K) = k in + IM.remove K.id self + + let equal (m1:t) (m2:t) : bool = + IM.equal + (fun p1 p2 -> + let Pair ((module K1), v1) = p1 in + let Pair ((module K2), v2) = p2 in + K1.id = K2.id && + match v1, v2 with K1.Store v1, K1.Store v2 -> K1.equal v1 v2 | _ -> false) + m1 m2 + + let merge (m1:t) (m2:t) : t = + IM.merge + (fun _ p1 p2 -> + match p1, p2 with + | None, None -> None + | Some v, None + | None, Some v -> Some v + | Some (Pair ((module K1) as key1, v1)), Some (Pair (_, v2)) -> + match v1, v2 with + | K1.Store v1, K1.Store v2 -> + (* merge content *) + Some (Pair (key1, K1.Store (K1.merge v1 v2))) + | _ -> assert false + ) + m1 m2 + end (** A node of the congruence closure. An equivalence class is represented by its "root" element, @@ -44,6 +136,7 @@ module Make(A: ARG) = struct mutable n_size: int; (* size of the class *) mutable n_as_lit: lit option; (* TODO: put into payload? and only in root? *) mutable n_expl: explanation_forest_link; (* the rooted forest for explanations *) + mutable n_th_data: K_map.t; (* theory data *) (* TODO: make a micro theory and move this inside *) mutable n_tags: (node * explanation) Util.Int_map.t; (* "distinct" tags (i.e. set of `(distinct t1…tn)` terms this belongs to *) @@ -58,6 +151,7 @@ module Make(A: ARG) = struct expl: explanation; } + (* TODO: make this recursive (the list case) *) (* atomic explanation in the congruence closure *) and explanation = | E_reduction (* by pure reduction, tautologically equal *) @@ -66,7 +160,6 @@ module Make(A: ARG) = struct | E_congruence of node * node (* caused by normal congruence *) | E_lit of lit (* because of this literal *) | E_lits of lit list (* because of this (true) conjunction *) - (* TODO: congruence case (cheaper than "merges") *) type repr = node type conflict = lit list @@ -91,6 +184,7 @@ module Make(A: ARG) = struct n_expl=FL_none; n_next=n; n_size=1; + n_th_data=K_map.empty; n_tags=Util.Int_map.empty; } in n @@ -213,7 +307,8 @@ module Make(A: ARG) = struct pending: node Vec.t; combine: combine_task Vec.t; undo: (unit -> unit) Backtrack_stack.t; - on_merge: (repr -> repr -> explanation -> unit) option; + mutable on_merge: (t -> repr -> repr -> explanation -> unit) list; + mutable on_new_term: (t -> repr -> term -> unit) list; mutable ps_lits: lit list; (* TODO: thread it around instead? *) (* proof state *) ps_queue: (node*node) Vec.t; @@ -230,6 +325,7 @@ module Make(A: ARG) = struct let[@inline] size_ (r:repr) = r.n_size let[@inline] true_ cc = Lazy.force cc.true_ let[@inline] false_ cc = Lazy.force cc.false_ + let[@inline] term_state cc = cc.tst let[@inline] on_backtrack cc f : unit = Backtrack_stack.push_if_nonzero_level cc.undo f @@ -237,6 +333,8 @@ module Make(A: ARG) = struct (* check if [t] is in the congruence closure. Invariant: [in_cc t ∧ do_cc t => forall u subterm t, in_cc u] *) let[@inline] mem (cc:t) (t:term): bool = T_tbl.mem cc.tbl t + let on_merge cc f = cc.on_merge <- f :: cc.on_merge + let on_new_term cc f = cc.on_new_term <- f :: cc.on_new_term (* find representative, recursively *) let[@unroll 2] rec find_rec (n:node) : repr = @@ -280,6 +378,28 @@ module Make(A: ARG) = struct (Util.pp_seq ~sep:" " pp_n) (T_tbl.values cc.tbl) (Util.pp_seq ~sep:" " pp_sig_e) (Sig_tbl.to_seq cc.signatures_tbl) + let th_data_get (_:t) (n:node) (key: _ Key.t) : _ option = + let n = find_ n in + K_map.find key n.n_th_data + + (* update data for [n] *) + let th_data_add (type a) (self:t) (n:node) (key: (term,a) Key.t) (v:a) : unit = + let n = find_ n in + let map = n.n_th_data in + let old_v = K_map.find key map in + let v', is_diff = match old_v with + | None -> v, true + | Some old_v -> + let (module K) = key in + let v' = K.merge old_v v in + v', K.equal v v' + in + if is_diff then ( + on_backtrack self (fun () -> n.n_th_data <- map); + ); + n.n_th_data <- K_map.add key v' map; + () + (* compute up-to-date signature *) let update_sig (s:signature) : Signature.t = CC_types.map_view s @@ -311,7 +431,7 @@ module Make(A: ARG) = struct Vec.push cc.pending t ) - let push_combine cc t u e : unit = + let merge_classes cc t u e : unit = Log.debugf 5 (fun k->k "(@[cc.push_combine@ %a ~@ %a@ :expl %a@])" N.pp t N.pp u Expl.pp e); @@ -491,6 +611,7 @@ module Make(A: ARG) = struct (* [n] might be merged with other equiv classes *) push_pending cc n; ); + List.iter (fun f -> f cc n t) cc.on_new_term; n (* compute the initial signature of the given node *) @@ -526,7 +647,6 @@ module Make(A: ARG) = struct return @@ If (deref_sub a, deref_sub b, deref_sub c) let[@inline] add_term cc t : node = add_term_rec_ cc t - let[@inline] add_term' cc t : unit = ignore (add_term_rec_ cc t : node) let set_as_lit cc (n:node) (lit:lit) : unit = match n.n_as_lit with @@ -536,15 +656,6 @@ module Make(A: ARG) = struct on_backtrack cc (fun () -> n.n_as_lit <- None); n.n_as_lit <- Some lit - (* Checks if [ra] and [~into] have compatible normal forms and can - be merged w.r.t. the theories. - Side effect: also pushes sub-tasks *) - let notify_merge cc (ra:repr) ~into:(rb:repr) (e:explanation): unit = - assert (N.is_root rb); - match cc.on_merge with - | Some f -> f ra rb e - | None -> () - let[@inline] n_is_bool (self:t) n : bool = N.equal n (true_ self) || N.equal n (false_ self) @@ -569,7 +680,7 @@ module Make(A: ARG) = struct (* if [a=b] is now true, merge [(a=b)] and [true] *) if same_class a b then ( let expl = Expl.mk_merge a b in - push_combine cc n (true_ cc) expl + merge_classes cc n (true_ cc) expl ) | Some s0 -> (* update the signature by using [find] on each sub-node *) @@ -584,7 +695,7 @@ module Make(A: ARG) = struct arguments that are pairwise equal *) assert (n != u); let expl = Expl.mk_congruence n u in - push_combine cc n u expl + merge_classes cc n u expl end (* TODO: remove, once we have moved distinct to a theory *) @@ -700,7 +811,7 @@ module Make(A: ARG) = struct a.n_expl <- FL_some {next=b; expl=e_ab}; end; (* notify listeners of the merge *) - notify_merge cc r_from ~into:r_into e_ab; + List.iter (fun f -> f cc r_into r_from e_ab) cc.on_merge ) and task_distinct_ cc acts (l:node list) tag expl : unit = @@ -816,7 +927,7 @@ module Make(A: ARG) = struct let a = add_term cc a in let b = add_term cc b in (* merge [a] and [b] *) - push_combine cc a b (Expl.mk_lit lit) + merge_classes cc a b (Expl.mk_lit lit) | _ -> (* equate t and true/false *) let rhs = if sign then true_ cc else false_ cc in @@ -825,7 +936,7 @@ module Make(A: ARG) = struct basically, just have [n] point to true/false and thus acquire the corresponding value, so its superterms (like [ite]) can evaluate properly *) - push_combine cc n rhs (Expl.mk_lit lit) + merge_classes cc n rhs (Expl.mk_lit lit) end let[@inline] assert_lits cc lits : unit = @@ -835,7 +946,7 @@ module Make(A: ARG) = struct let expl = Expl.mk_lits e in let n1 = add_term cc t1 in let n2 = add_term cc t2 in - push_combine cc n1 n2 expl + merge_classes cc n1 n2 expl (* generative tag used to annotate classes that can't be merged *) let distinct_tag_ = ref 0 @@ -848,13 +959,13 @@ module Make(A: ARG) = struct let l = List.map (add_term cc) l in Vec.push cc.combine @@ CT_distinct (l, tag, Expl.mk_lit lit) - let create ?on_merge ?(size=`Big) (tst:term_state) : t = + let create ?(on_merge=[]) ?(on_new_term=[]) ?(size=`Big) (tst:term_state) : t = let size = match size with `Small -> 128 | `Big -> 2048 in let rec cc = { tst; tbl = T_tbl.create size; signatures_tbl = Sig_tbl.create size; - on_merge; + on_merge; on_new_term; pending=Vec.create(); combine=Vec.create(); ps_lits=[]; diff --git a/src/cc/Congruence_closure.mli b/src/cc/Congruence_closure.mli index e2629ad5..31982030 100644 --- a/src/cc/Congruence_closure.mli +++ b/src/cc/Congruence_closure.mli @@ -2,6 +2,12 @@ module type ARG = Congruence_closure_intf.ARG module type S = Congruence_closure_intf.S +module type THEORY_DATA = Congruence_closure_intf.THEORY_DATA +module type THEORY_KEY = Congruence_closure_intf.THEORY_KEY + +type ('t, 'a) theory_data = ('t,'a) Congruence_closure_intf.theory_data + +module Key : THEORY_KEY module Make(A: ARG) : S with type term = A.Term.t @@ -10,3 +16,4 @@ module Make(A: ARG) and type term_state = A.Term.state and type proof = A.Proof.t and type model = A.Model.t + and module Key = Key diff --git a/src/cc/Congruence_closure_intf.ml b/src/cc/Congruence_closure_intf.ml index a4aad927..76c74fc9 100644 --- a/src/cc/Congruence_closure_intf.ml +++ b/src/cc/Congruence_closure_intf.ml @@ -1,7 +1,39 @@ module type ARG = CC_types.FULL -module type S0 = sig +(** Data stored by a theory for its own terms. + + It needs to form a commutative monoid where values can be unmerged upon + backtracking. +*) +module type THEORY_DATA = sig + type term + type t + + val empty : t + + val equal : t -> t -> bool + (** Equality. This is used to optimize backtracking info. *) + + val merge : t -> t -> t + (** [merge d1 d2] is called when merging classes with data [d1] and [d2] + respectively. The theory should already have checked that the merge + is compatible, and this produces the combined data for terms in the + merged class. *) +end + +type ('t, 'a) theory_data = (module THEORY_DATA with type term = 't and type t = 'a) + +module type THEORY_KEY = sig + type ('t, 'a) t + (** An access key for theories that use terms ['t] and which have + per-class data ['a] *) + + val create : ('t, 'a) theory_data -> ('t, 'a) t + (** Generative creation of keys for the given theory data. *) +end + +module type S = sig type term_state type term type fun_ @@ -9,13 +41,12 @@ module type S0 = sig type proof type model - (** Actions available to the theory *) - type sat_actions = (Msat.void, lit, Msat.void, proof) Msat.acts + (** Implementation of theory keys *) + module Key : THEORY_KEY type t (** Global state of the congruence closure *) - (** An equivalence class is a set of terms that are currently equal in the partial model built by the solver. The class is represented by a collection of nodes, one of which is @@ -74,18 +105,9 @@ module type S0 = sig type conflict = lit list - (* TODO: notion of micro theory, parametrized by [on_backtrack, find, etc] - and with callbacks for on_merge? *) + (** Accessors *) - (* TODO micro theories as parameters *) - val create : - ?on_merge:(repr -> repr -> explanation -> unit) -> - ?size:[`Small | `Big] -> - term_state -> - t - (** Create a new congruence closure. - @param on_merge callback to be called on every merge - *) + val term_state : t -> term_state val find : t -> node -> repr (** Current representative *) @@ -94,12 +116,41 @@ module type S0 = sig (** Add the term to the congruence closure, if not present already. Will be backtracked. *) + (** Actions available to the theory *) + type sat_actions = (Msat.void, lit, Msat.void, proof) Msat.acts + + val create : + ?on_merge:(t -> repr -> repr -> explanation -> unit) list -> + ?on_new_term:(t -> repr -> term -> unit) list -> + ?size:[`Small | `Big] -> + term_state -> + t + (** Create a new congruence closure. *) + + val on_merge : t -> (t -> repr -> repr -> explanation -> unit) -> unit + (** Add a callback, to be called whenever two classes are merged *) + + val on_new_term : t -> (t -> repr -> term -> unit) -> unit + (** Add a callback, to be called whenever a node is added *) + + val merge_classes : t -> node -> node -> explanation -> unit + (** Merge the two given nodes with given explanation. + It must be a theory tautology that [expl ==> n1 = n2] *) + + val th_data_get : t -> N.t -> (term, 'a) Key.t -> 'a option + (** Get data information for this particular representative *) + + val th_data_add : t -> N.t -> (term, 'a) Key.t -> 'a -> unit + (** Add the given data to this node (or rather, to its representative). + This will be backtracked. *) + + (* TODO: merge true/false? + val raise_conflict : CC.t -> CC.N.t -> CC.N.t -> Expl.t -> 'a + *) + val set_as_lit : t -> N.t -> lit -> unit (** map the given node to a literal. *) - val add_term' : t -> term -> unit - (** Same as {!add_term} but ignore the result *) - val find_t : t -> term -> repr (** Current representative of the term. @raise Not_found if the term is not already {!add}-ed. *) @@ -108,17 +159,21 @@ module type S0 = sig (** Add a sequence of terms to the congruence closure *) val all_classes : t -> repr Sequence.t - (** All current classes *) + (** All current classes. This is costly, only use if there is no other solution *) val assert_lit : t -> lit -> unit (** Given a literal, assume it in the congruence closure and propagate - its consequences. Will be backtracked. *) + its consequences. Will be backtracked. + + Useful for the theory combination or the SAT solver's functor *) val assert_lits : t -> lit Sequence.t -> unit + (** Addition of many literals *) val assert_eq : t -> term -> term -> lit list -> unit (** merge the given terms with some explanations *) + (* TODO: remove and move into its own library as a micro theory *) val assert_distinct : t -> term list -> neq:term -> lit -> unit (** [assert_distinct l ~neq:u e] asserts all elements of [l] are distinct because [lit] is true @@ -129,8 +184,10 @@ module type S0 = sig Will use the [sat_actions] to propagate literals, declare conflicts, etc. *) val push_level : t -> unit + (** Push backtracking level *) val pop_levels : t -> int -> unit + (** Restore to state [n] calls to [push_level] earlier. Used during backtracking. *) val mk_model : t -> model -> model (** Enrich a model by mapping terms to their representative's value, @@ -140,11 +197,5 @@ module type S0 = sig val check_invariants : t -> unit val pp_full : t Fmt.printer (**/**) -end - -module type S = sig - - include S0 - end diff --git a/src/smt/Theory.ml b/src/smt/Theory.ml index 71893f93..4c440366 100644 --- a/src/smt/Theory.ml +++ b/src/smt/Theory.ml @@ -71,8 +71,10 @@ module type S = sig val create : Term.state -> t (** Instantiate the theory's state *) + (* TODO: instead pass Congruence_closure.theory to [create] val on_merge: t -> actions -> CC_eq_class.t -> CC_eq_class.t -> CC_expl.t -> unit (** Called when two classes are merged *) + *) val partial_check : t -> actions -> Lit.t Sequence.t -> unit (** Called when a literal becomes true *) diff --git a/src/smt/Theory_combine.ml b/src/smt/Theory_combine.ml index 991dfa46..a9e306ec 100644 --- a/src/smt/Theory_combine.ml +++ b/src/smt/Theory_combine.ml @@ -29,7 +29,6 @@ type t = { (** congruence closure *) mutable theories : theory_state list; (** Set of theories *) - new_merges: (Eq_class.t * Eq_class.t * Expl.t) Vec.t; } let[@inline] cc (t:t) = Lazy.force t.cc @@ -45,7 +44,6 @@ let assert_lits_ ~final (self:t) acts (lits:Lit.t Sequence.t) : unit = (fun k->k "(@[@{th_combine.assume_lits@}%s@ %a@])" (if final then "[final]" else "") (Util.pp_seq ~sep:"; " Lit.pp) lits); (* transmit to CC *) - Vec.clear self.new_merges; let cc = cc self in if not final then ( CC.assert_lits cc lits; @@ -71,7 +69,6 @@ let assert_lits_ ~final (self:t) acts (lits:Lit.t Sequence.t) : unit = theories self (fun (Th_state ((module Th),st)) -> (* give new merges, then call {final,partial}-check *) - Vec.iter (fun (r1,r2,e) -> Th.on_merge st acts r1 r2 e) self.new_merges; if final then Th.final_check st acts lits else Th.partial_check st acts lits); () @@ -123,10 +120,6 @@ let mk_model (self:t) lits : Model.t = (** {2 Interface to Congruence Closure} *) -(* when CC decided to merge [r1] and [r2], notify theories *) -let[@inline] on_merge_from_cc (self:t) r1 r2 e : unit = - Vec.push self.new_merges (r1,r2,e) - (** {2 Main} *) (* create a new theory combination *) @@ -134,11 +127,10 @@ let create () : t = Log.debug 5 "th_combine.create"; let rec self = { tst=Term.create ~size:1024 (); - new_merges=Vec.create(); cc = lazy ( (* lazily tie the knot *) - let on_merge = on_merge_from_cc self in - CC.create ~on_merge ~size:`Big self.tst; + (* TODO: pass theories *) + CC.create ~size:`Big self.tst; ); theories = []; } in