wip: micro theories

This commit is contained in:
Simon Cruanes 2019-02-22 20:57:17 -06:00
parent 77a5475862
commit c79a5a4798
5 changed files with 220 additions and 57 deletions

View file

@ -3,6 +3,42 @@ open CC_types
module type ARG = Congruence_closure_intf.ARG
module type S = Congruence_closure_intf.S
module type THEORY_DATA = Congruence_closure_intf.THEORY_DATA
module type THEORY_KEY = Congruence_closure_intf.THEORY_KEY
type ('t, 'a) theory_data = ('t,'a) Congruence_closure_intf.theory_data
module type KEY_IMPL = sig
include THEORY_DATA
exception Store of t
val id : int
end
(** Custom keys for theory data.
This imitates the classic tricks for heterogeneous maps
https://blog.janestreet.com/a-universal-type/
*)
module Key = struct
type ('term, 'a) t = (module KEY_IMPL with type term = 'term and type t = 'a)
let n_ = ref 0
let create (type term)(type d) (th:(term,d) theory_data) : (term,d) t =
let (module TH) = th in
let module K = struct
include TH
exception Store of d
let id = !n_
end in
incr n_;
(module K)
let id (module K : KEY_IMPL) = K.id
let equal
: type a b term. (term,a) t -> (term,b) t -> bool
= fun (module K1) (module K2) -> K1.id = K2.id
end
module Bits = CCBitField.Make()
@ -30,6 +66,62 @@ module Make(A: ARG) = struct
module T = A.Term
module Fun = A.Fun
module Key = Key
(** Map for theory data associated with representatives *)
module K_map = struct
type pair = Pair : (term, 'a) Key.t * exn -> pair
module IM = Map.Make(CCInt)
type t = pair IM.t
let empty = IM.empty
let[@inline] mem k t = IM.mem (Key.id k) t
let is_empty = IM.is_empty
let find (type a) (k : (term,a) Key.t) (self:t) : a option =
let (module K) = k in
match IM.find K.id self with
| Pair (_, K.Store v) -> Some v
| _ -> None
| exception Not_found -> None
let add (type a) (k : (term,a) Key.t) (v:a) (self:t) : t =
let (module K) = k in
IM.add K.id (Pair (k, K.Store v)) self
let remove (type a) (k: (term,a) Key.t) self : t =
let (module K) = k in
IM.remove K.id self
let equal (m1:t) (m2:t) : bool =
IM.equal
(fun p1 p2 ->
let Pair ((module K1), v1) = p1 in
let Pair ((module K2), v2) = p2 in
K1.id = K2.id &&
match v1, v2 with K1.Store v1, K1.Store v2 -> K1.equal v1 v2 | _ -> false)
m1 m2
let merge (m1:t) (m2:t) : t =
IM.merge
(fun _ p1 p2 ->
match p1, p2 with
| None, None -> None
| Some v, None
| None, Some v -> Some v
| Some (Pair ((module K1) as key1, v1)), Some (Pair (_, v2)) ->
match v1, v2 with
| K1.Store v1, K1.Store v2 ->
(* merge content *)
Some (Pair (key1, K1.Store (K1.merge v1 v2)))
| _ -> assert false
)
m1 m2
end
(** A node of the congruence closure.
An equivalence class is represented by its "root" element,
@ -44,6 +136,7 @@ module Make(A: ARG) = struct
mutable n_size: int; (* size of the class *)
mutable n_as_lit: lit option; (* TODO: put into payload? and only in root? *)
mutable n_expl: explanation_forest_link; (* the rooted forest for explanations *)
mutable n_th_data: K_map.t; (* theory data *)
(* TODO: make a micro theory and move this inside *)
mutable n_tags: (node * explanation) Util.Int_map.t;
(* "distinct" tags (i.e. set of `(distinct t1…tn)` terms this belongs to *)
@ -58,6 +151,7 @@ module Make(A: ARG) = struct
expl: explanation;
}
(* TODO: make this recursive (the list case) *)
(* atomic explanation in the congruence closure *)
and explanation =
| E_reduction (* by pure reduction, tautologically equal *)
@ -66,7 +160,6 @@ module Make(A: ARG) = struct
| E_congruence of node * node (* caused by normal congruence *)
| E_lit of lit (* because of this literal *)
| E_lits of lit list (* because of this (true) conjunction *)
(* TODO: congruence case (cheaper than "merges") *)
type repr = node
type conflict = lit list
@ -91,6 +184,7 @@ module Make(A: ARG) = struct
n_expl=FL_none;
n_next=n;
n_size=1;
n_th_data=K_map.empty;
n_tags=Util.Int_map.empty;
} in
n
@ -213,7 +307,8 @@ module Make(A: ARG) = struct
pending: node Vec.t;
combine: combine_task Vec.t;
undo: (unit -> unit) Backtrack_stack.t;
on_merge: (repr -> repr -> explanation -> unit) option;
mutable on_merge: (t -> repr -> repr -> explanation -> unit) list;
mutable on_new_term: (t -> repr -> term -> unit) list;
mutable ps_lits: lit list; (* TODO: thread it around instead? *)
(* proof state *)
ps_queue: (node*node) Vec.t;
@ -230,6 +325,7 @@ module Make(A: ARG) = struct
let[@inline] size_ (r:repr) = r.n_size
let[@inline] true_ cc = Lazy.force cc.true_
let[@inline] false_ cc = Lazy.force cc.false_
let[@inline] term_state cc = cc.tst
let[@inline] on_backtrack cc f : unit =
Backtrack_stack.push_if_nonzero_level cc.undo f
@ -237,6 +333,8 @@ module Make(A: ARG) = struct
(* check if [t] is in the congruence closure.
Invariant: [in_cc t do_cc t => forall u subterm t, in_cc u] *)
let[@inline] mem (cc:t) (t:term): bool = T_tbl.mem cc.tbl t
let on_merge cc f = cc.on_merge <- f :: cc.on_merge
let on_new_term cc f = cc.on_new_term <- f :: cc.on_new_term
(* find representative, recursively *)
let[@unroll 2] rec find_rec (n:node) : repr =
@ -280,6 +378,28 @@ module Make(A: ARG) = struct
(Util.pp_seq ~sep:" " pp_n) (T_tbl.values cc.tbl)
(Util.pp_seq ~sep:" " pp_sig_e) (Sig_tbl.to_seq cc.signatures_tbl)
let th_data_get (_:t) (n:node) (key: _ Key.t) : _ option =
let n = find_ n in
K_map.find key n.n_th_data
(* update data for [n] *)
let th_data_add (type a) (self:t) (n:node) (key: (term,a) Key.t) (v:a) : unit =
let n = find_ n in
let map = n.n_th_data in
let old_v = K_map.find key map in
let v', is_diff = match old_v with
| None -> v, true
| Some old_v ->
let (module K) = key in
let v' = K.merge old_v v in
v', K.equal v v'
in
if is_diff then (
on_backtrack self (fun () -> n.n_th_data <- map);
);
n.n_th_data <- K_map.add key v' map;
()
(* compute up-to-date signature *)
let update_sig (s:signature) : Signature.t =
CC_types.map_view s
@ -311,7 +431,7 @@ module Make(A: ARG) = struct
Vec.push cc.pending t
)
let push_combine cc t u e : unit =
let merge_classes cc t u e : unit =
Log.debugf 5
(fun k->k "(@[<hv1>cc.push_combine@ %a ~@ %a@ :expl %a@])"
N.pp t N.pp u Expl.pp e);
@ -491,6 +611,7 @@ module Make(A: ARG) = struct
(* [n] might be merged with other equiv classes *)
push_pending cc n;
);
List.iter (fun f -> f cc n t) cc.on_new_term;
n
(* compute the initial signature of the given node *)
@ -526,7 +647,6 @@ module Make(A: ARG) = struct
return @@ If (deref_sub a, deref_sub b, deref_sub c)
let[@inline] add_term cc t : node = add_term_rec_ cc t
let[@inline] add_term' cc t : unit = ignore (add_term_rec_ cc t : node)
let set_as_lit cc (n:node) (lit:lit) : unit =
match n.n_as_lit with
@ -536,15 +656,6 @@ module Make(A: ARG) = struct
on_backtrack cc (fun () -> n.n_as_lit <- None);
n.n_as_lit <- Some lit
(* Checks if [ra] and [~into] have compatible normal forms and can
be merged w.r.t. the theories.
Side effect: also pushes sub-tasks *)
let notify_merge cc (ra:repr) ~into:(rb:repr) (e:explanation): unit =
assert (N.is_root rb);
match cc.on_merge with
| Some f -> f ra rb e
| None -> ()
let[@inline] n_is_bool (self:t) n : bool =
N.equal n (true_ self) || N.equal n (false_ self)
@ -569,7 +680,7 @@ module Make(A: ARG) = struct
(* if [a=b] is now true, merge [(a=b)] and [true] *)
if same_class a b then (
let expl = Expl.mk_merge a b in
push_combine cc n (true_ cc) expl
merge_classes cc n (true_ cc) expl
)
| Some s0 ->
(* update the signature by using [find] on each sub-node *)
@ -584,7 +695,7 @@ module Make(A: ARG) = struct
arguments that are pairwise equal *)
assert (n != u);
let expl = Expl.mk_congruence n u in
push_combine cc n u expl
merge_classes cc n u expl
end
(* TODO: remove, once we have moved distinct to a theory *)
@ -700,7 +811,7 @@ module Make(A: ARG) = struct
a.n_expl <- FL_some {next=b; expl=e_ab};
end;
(* notify listeners of the merge *)
notify_merge cc r_from ~into:r_into e_ab;
List.iter (fun f -> f cc r_into r_from e_ab) cc.on_merge
)
and task_distinct_ cc acts (l:node list) tag expl : unit =
@ -816,7 +927,7 @@ module Make(A: ARG) = struct
let a = add_term cc a in
let b = add_term cc b in
(* merge [a] and [b] *)
push_combine cc a b (Expl.mk_lit lit)
merge_classes cc a b (Expl.mk_lit lit)
| _ ->
(* equate t and true/false *)
let rhs = if sign then true_ cc else false_ cc in
@ -825,7 +936,7 @@ module Make(A: ARG) = struct
basically, just have [n] point to true/false and thus acquire
the corresponding value, so its superterms (like [ite]) can evaluate
properly *)
push_combine cc n rhs (Expl.mk_lit lit)
merge_classes cc n rhs (Expl.mk_lit lit)
end
let[@inline] assert_lits cc lits : unit =
@ -835,7 +946,7 @@ module Make(A: ARG) = struct
let expl = Expl.mk_lits e in
let n1 = add_term cc t1 in
let n2 = add_term cc t2 in
push_combine cc n1 n2 expl
merge_classes cc n1 n2 expl
(* generative tag used to annotate classes that can't be merged *)
let distinct_tag_ = ref 0
@ -848,13 +959,13 @@ module Make(A: ARG) = struct
let l = List.map (add_term cc) l in
Vec.push cc.combine @@ CT_distinct (l, tag, Expl.mk_lit lit)
let create ?on_merge ?(size=`Big) (tst:term_state) : t =
let create ?(on_merge=[]) ?(on_new_term=[]) ?(size=`Big) (tst:term_state) : t =
let size = match size with `Small -> 128 | `Big -> 2048 in
let rec cc = {
tst;
tbl = T_tbl.create size;
signatures_tbl = Sig_tbl.create size;
on_merge;
on_merge; on_new_term;
pending=Vec.create();
combine=Vec.create();
ps_lits=[];

View file

@ -2,6 +2,12 @@
module type ARG = Congruence_closure_intf.ARG
module type S = Congruence_closure_intf.S
module type THEORY_DATA = Congruence_closure_intf.THEORY_DATA
module type THEORY_KEY = Congruence_closure_intf.THEORY_KEY
type ('t, 'a) theory_data = ('t,'a) Congruence_closure_intf.theory_data
module Key : THEORY_KEY
module Make(A: ARG)
: S with type term = A.Term.t
@ -10,3 +16,4 @@ module Make(A: ARG)
and type term_state = A.Term.state
and type proof = A.Proof.t
and type model = A.Model.t
and module Key = Key

View file

@ -1,7 +1,39 @@
module type ARG = CC_types.FULL
module type S0 = sig
(** Data stored by a theory for its own terms.
It needs to form a commutative monoid where values can be unmerged upon
backtracking.
*)
module type THEORY_DATA = sig
type term
type t
val empty : t
val equal : t -> t -> bool
(** Equality. This is used to optimize backtracking info. *)
val merge : t -> t -> t
(** [merge d1 d2] is called when merging classes with data [d1] and [d2]
respectively. The theory should already have checked that the merge
is compatible, and this produces the combined data for terms in the
merged class. *)
end
type ('t, 'a) theory_data = (module THEORY_DATA with type term = 't and type t = 'a)
module type THEORY_KEY = sig
type ('t, 'a) t
(** An access key for theories that use terms ['t] and which have
per-class data ['a] *)
val create : ('t, 'a) theory_data -> ('t, 'a) t
(** Generative creation of keys for the given theory data. *)
end
module type S = sig
type term_state
type term
type fun_
@ -9,13 +41,12 @@ module type S0 = sig
type proof
type model
(** Actions available to the theory *)
type sat_actions = (Msat.void, lit, Msat.void, proof) Msat.acts
(** Implementation of theory keys *)
module Key : THEORY_KEY
type t
(** Global state of the congruence closure *)
(** An equivalence class is a set of terms that are currently equal
in the partial model built by the solver.
The class is represented by a collection of nodes, one of which is
@ -74,18 +105,9 @@ module type S0 = sig
type conflict = lit list
(* TODO: notion of micro theory, parametrized by [on_backtrack, find, etc]
and with callbacks for on_merge? *)
(** Accessors *)
(* TODO micro theories as parameters *)
val create :
?on_merge:(repr -> repr -> explanation -> unit) ->
?size:[`Small | `Big] ->
term_state ->
t
(** Create a new congruence closure.
@param on_merge callback to be called on every merge
*)
val term_state : t -> term_state
val find : t -> node -> repr
(** Current representative *)
@ -94,12 +116,41 @@ module type S0 = sig
(** Add the term to the congruence closure, if not present already.
Will be backtracked. *)
(** Actions available to the theory *)
type sat_actions = (Msat.void, lit, Msat.void, proof) Msat.acts
val create :
?on_merge:(t -> repr -> repr -> explanation -> unit) list ->
?on_new_term:(t -> repr -> term -> unit) list ->
?size:[`Small | `Big] ->
term_state ->
t
(** Create a new congruence closure. *)
val on_merge : t -> (t -> repr -> repr -> explanation -> unit) -> unit
(** Add a callback, to be called whenever two classes are merged *)
val on_new_term : t -> (t -> repr -> term -> unit) -> unit
(** Add a callback, to be called whenever a node is added *)
val merge_classes : t -> node -> node -> explanation -> unit
(** Merge the two given nodes with given explanation.
It must be a theory tautology that [expl ==> n1 = n2] *)
val th_data_get : t -> N.t -> (term, 'a) Key.t -> 'a option
(** Get data information for this particular representative *)
val th_data_add : t -> N.t -> (term, 'a) Key.t -> 'a -> unit
(** Add the given data to this node (or rather, to its representative).
This will be backtracked. *)
(* TODO: merge true/false?
val raise_conflict : CC.t -> CC.N.t -> CC.N.t -> Expl.t -> 'a
*)
val set_as_lit : t -> N.t -> lit -> unit
(** map the given node to a literal. *)
val add_term' : t -> term -> unit
(** Same as {!add_term} but ignore the result *)
val find_t : t -> term -> repr
(** Current representative of the term.
@raise Not_found if the term is not already {!add}-ed. *)
@ -108,17 +159,21 @@ module type S0 = sig
(** Add a sequence of terms to the congruence closure *)
val all_classes : t -> repr Sequence.t
(** All current classes *)
(** All current classes. This is costly, only use if there is no other solution *)
val assert_lit : t -> lit -> unit
(** Given a literal, assume it in the congruence closure and propagate
its consequences. Will be backtracked. *)
its consequences. Will be backtracked.
Useful for the theory combination or the SAT solver's functor *)
val assert_lits : t -> lit Sequence.t -> unit
(** Addition of many literals *)
val assert_eq : t -> term -> term -> lit list -> unit
(** merge the given terms with some explanations *)
(* TODO: remove and move into its own library as a micro theory *)
val assert_distinct : t -> term list -> neq:term -> lit -> unit
(** [assert_distinct l ~neq:u e] asserts all elements of [l] are distinct
because [lit] is true
@ -129,8 +184,10 @@ module type S0 = sig
Will use the [sat_actions] to propagate literals, declare conflicts, etc. *)
val push_level : t -> unit
(** Push backtracking level *)
val pop_levels : t -> int -> unit
(** Restore to state [n] calls to [push_level] earlier. Used during backtracking. *)
val mk_model : t -> model -> model
(** Enrich a model by mapping terms to their representative's value,
@ -140,11 +197,5 @@ module type S0 = sig
val check_invariants : t -> unit
val pp_full : t Fmt.printer
(**/**)
end
module type S = sig
include S0
end

View file

@ -71,8 +71,10 @@ module type S = sig
val create : Term.state -> t
(** Instantiate the theory's state *)
(* TODO: instead pass Congruence_closure.theory to [create]
val on_merge: t -> actions -> CC_eq_class.t -> CC_eq_class.t -> CC_expl.t -> unit
(** Called when two classes are merged *)
*)
val partial_check : t -> actions -> Lit.t Sequence.t -> unit
(** Called when a literal becomes true *)

View file

@ -29,7 +29,6 @@ type t = {
(** congruence closure *)
mutable theories : theory_state list;
(** Set of theories *)
new_merges: (Eq_class.t * Eq_class.t * Expl.t) Vec.t;
}
let[@inline] cc (t:t) = Lazy.force t.cc
@ -45,7 +44,6 @@ let assert_lits_ ~final (self:t) acts (lits:Lit.t Sequence.t) : unit =
(fun k->k "(@[<hv1>@{<green>th_combine.assume_lits@}%s@ %a@])"
(if final then "[final]" else "") (Util.pp_seq ~sep:"; " Lit.pp) lits);
(* transmit to CC *)
Vec.clear self.new_merges;
let cc = cc self in
if not final then (
CC.assert_lits cc lits;
@ -71,7 +69,6 @@ let assert_lits_ ~final (self:t) acts (lits:Lit.t Sequence.t) : unit =
theories self
(fun (Th_state ((module Th),st)) ->
(* give new merges, then call {final,partial}-check *)
Vec.iter (fun (r1,r2,e) -> Th.on_merge st acts r1 r2 e) self.new_merges;
if final then Th.final_check st acts lits else Th.partial_check st acts lits);
()
@ -123,10 +120,6 @@ let mk_model (self:t) lits : Model.t =
(** {2 Interface to Congruence Closure} *)
(* when CC decided to merge [r1] and [r2], notify theories *)
let[@inline] on_merge_from_cc (self:t) r1 r2 e : unit =
Vec.push self.new_merges (r1,r2,e)
(** {2 Main} *)
(* create a new theory combination *)
@ -134,11 +127,10 @@ let create () : t =
Log.debug 5 "th_combine.create";
let rec self = {
tst=Term.create ~size:1024 ();
new_merges=Vec.create();
cc = lazy (
(* lazily tie the knot *)
let on_merge = on_merge_from_cc self in
CC.create ~on_merge ~size:`Big self.tst;
(* TODO: pass theories *)
CC.create ~size:`Big self.tst;
);
theories = [];
} in