wip: new micro-theories in CC

This commit is contained in:
Simon Cruanes 2019-02-26 22:46:40 -06:00
parent 57147cea85
commit 342dba4533
19 changed files with 307 additions and 294 deletions

View file

@ -3,41 +3,62 @@ open CC_types
module type ARG = Congruence_closure_intf.ARG
module type S = Congruence_closure_intf.S
module type THEORY_DATA = Congruence_closure_intf.THEORY_DATA
module type THEORY_KEY = Congruence_closure_intf.THEORY_KEY
type ('t, 'a) theory_data = ('t,'a) Congruence_closure_intf.theory_data
module type KEY_IMPL = sig
include THEORY_DATA
exception Store of t
val id : int
end
(** Custom keys for theory data.
This imitates the classic tricks for heterogeneous maps
https://blog.janestreet.com/a-universal-type/
*)
It needs to form a commutative monoid where values are persistent so
they can be restored during backtracking.
*)
module Key = struct
type ('term, 'a) t = (module KEY_IMPL with type term = 'term and type t = 'a)
module type KEY_IMPL = sig
type term
type lit
type t
val id : int
val name : string
val pp : t Fmt.printer
val equal : t -> t -> bool
val merge : t -> t -> t
exception Store of t
end
type ('term,'lit,'a) t =
(module KEY_IMPL with type term = 'term and type lit = 'lit and type t = 'a)
let n_ = ref 0
let create (type term)(type d) (th:(term,d) theory_data) : (term,d) t =
let (module TH) = th in
let create (type term)(type lit)(type d)
?(pp=fun out _ -> Fmt.string out "<opaque>")
~name ~eq ~merge () : (term,lit,d) t =
let module K = struct
include TH
exception Store of d
type nonrec term = term
type nonrec lit = lit
type t = d
let id = !n_
let name = name
let pp = pp
let merge = merge
let equal = eq
exception Store of d
end in
incr n_;
(module K)
let id (module K : KEY_IMPL) = K.id
let[@inline] id
: type term lit a. (term,lit,a) t -> int
= fun (module K) -> K.id
let equal
: type a b term. (term,a) t -> (term,b) t -> bool
let[@inline] equal
: type term lit a b. (term,lit,a) t -> (term,lit,b) t -> bool
= fun (module K1) (module K2) -> K1.id = K2.id
let pp
: type term lit a. (term,lit,a) t Fmt.printer
= fun out (module K) -> Fmt.string out K.name
end
module Bits = CCBitField.Make()
@ -67,12 +88,12 @@ module Make(A: ARG) = struct
module T = A.Term
module Fun = A.Fun
module Key = Key
module IM = Map.Make(CCInt)
(** Map for theory data associated with representatives *)
module K_map = struct
type pair = Pair : (term, 'a) Key.t * exn -> pair
module IM = Map.Make(CCInt)
type 'a key = (term,lit,'a) Key.t
type pair = Pair : 'a key * exn -> pair
type t = pair IM.t
@ -80,20 +101,18 @@ module Make(A: ARG) = struct
let[@inline] mem k t = IM.mem (Key.id k) t
let is_empty = IM.is_empty
let find (type a) (k : (term,a) Key.t) (self:t) : a option =
let find (type a) (k : a key) (self:t) : a option =
let (module K) = k in
match IM.find K.id self with
| Pair (_, K.Store v) -> Some v
| _ -> None
| exception Not_found -> None
let add (type a) (k : (term,a) Key.t) (v:a) (self:t) : t =
let add (type a) (k : a key) (v:a) (self:t) : t =
let (module K) = k in
IM.add K.id (Pair (k, K.Store v)) self
let remove (type a) (k: (term,a) Key.t) self : t =
let remove (type a) (k: a key) self : t =
let (module K) = k in
IM.remove K.id self
@ -102,22 +121,23 @@ module Make(A: ARG) = struct
(fun p1 p2 ->
let Pair ((module K1), v1) = p1 in
let Pair ((module K2), v2) = p2 in
K1.id = K2.id &&
assert (K1.id = K2.id);
match v1, v2 with K1.Store v1, K1.Store v2 -> K1.equal v1 v2 | _ -> false)
m1 m2
let merge (m1:t) (m2:t) : t =
let merge ~f_both (m1:t) (m2:t) : t =
IM.merge
(fun _ p1 p2 ->
match p1, p2 with
| None, None -> None
| Some v, None
| None, Some v -> Some v
| Some (Pair ((module K1) as key1, v1)), Some (Pair (_, v2)) ->
match v1, v2 with
| Some (Pair ((module K1) as key1, pair1)), Some (Pair (_, pair2)) ->
match pair1, pair2 with
| K1.Store v1, K1.Store v2 ->
(* merge content *)
Some (Pair (key1, K1.Store (K1.merge v1 v2)))
f_both K1.id pair1 pair2; (* callback for checking compat *)
let v12 = K1.merge v1 v2 in (* merge content *)
Some (Pair (key1, K1.Store v12))
| _ -> assert false
)
m1 m2
@ -137,9 +157,6 @@ module Make(A: ARG) = struct
mutable n_as_lit: lit option; (* TODO: put into payload? and only in root? *)
mutable n_expl: explanation_forest_link; (* the rooted forest for explanations *)
mutable n_th_data: K_map.t; (* theory data *)
(* TODO: make a micro theory and move this inside *)
mutable n_tags: (node * explanation) Util.Int_map.t;
(* "distinct" tags (i.e. set of `(distinct t1…tn)` terms this belongs to *)
}
and signature = (fun_, node, node list) view
@ -151,15 +168,13 @@ module Make(A: ARG) = struct
expl: explanation;
}
(* TODO: make this recursive (the list case) *)
(* atomic explanation in the congruence closure *)
and explanation =
| E_reduction (* by pure reduction, tautologically equal *)
| E_merge of node * node
| E_merges of (node * node) list (* caused by these merges *)
| E_congruence of node * node (* caused by normal congruence *)
| E_lit of lit (* because of this literal *)
| E_lits of lit list (* because of this (true) conjunction *)
| E_merge of node * node
| E_list of explanation list
| E_congruence of node * node (* caused by normal congruence *)
type repr = node
type conflict = lit list
@ -185,7 +200,6 @@ module Make(A: ARG) = struct
n_next=n;
n_size=1;
n_th_data=K_map.empty;
n_tags=Util.Int_map.empty;
} in
n
@ -217,30 +231,24 @@ module Make(A: ARG) = struct
module Expl = struct
type t = explanation
let pp out (e:explanation) = match e with
let rec pp out (e:explanation) = match e with
| E_reduction -> Fmt.string out "reduction"
| E_lit lit -> A.Lit.pp out lit
| E_congruence (n1,n2) -> Fmt.fprintf out "(@[congruence@ %a@ %a@])" N.pp n1 N.pp n2
| E_lits l -> CCFormat.Dump.list A.Lit.pp out l
| E_merge (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" N.pp a N.pp b
| E_merges l ->
Format.fprintf out "(@[<hv1>merges@ %a@])"
Fmt.(seq ~sep:(return "@ ") @@ within "[" "]" @@ hvbox @@
pair ~sep:(return " ~@ ") N.pp N.pp)
(Sequence.of_list l)
| E_list l ->
Format.fprintf out "(@[<hv1>and@ %a@])"
Fmt.(list ~sep:(return "@ ") @@ within "[" "]" @@ hvbox @@ pp) l
let mk_reduction : t = E_reduction
let[@inline] mk_congruence n1 n2 : t = E_congruence (n1,n2)
let[@inline] mk_merge a b : t = E_merge (a,b)
let[@inline] mk_merges = function
| [] -> mk_reduction
| [(a,b)] -> mk_merge a b
| l -> E_merges l
let[@inline] mk_lit l : t = E_lit l
let[@inline] mk_lits = function
let mk_list l =
match l with
| [] -> mk_reduction
| [x] -> mk_lit x
| l -> E_lits l
| [x] -> x
| l -> E_list l
end
(** A signature is a shallow term shape where immediate subterms
@ -290,7 +298,15 @@ module Make(A: ARG) = struct
type combine_task =
| CT_merge of node * node * explanation
| CT_distinct of node list * int * explanation
module type THEORY = sig
type cc
type data
val key_id : int
val key : (term,lit,data) Key.t
val on_merge : cc -> N.t -> data -> N.t -> data -> Expl.t -> unit
val on_new_term: cc -> term -> data option
end
type t = {
tst: term_state;
@ -307,8 +323,7 @@ module Make(A: ARG) = struct
pending: node Vec.t;
combine: combine_task Vec.t;
undo: (unit -> unit) Backtrack_stack.t;
mutable on_merge: (t -> repr -> repr -> explanation -> unit) list;
mutable on_new_term: (t -> repr -> term -> unit) list;
mutable theories: theory IM.t;
mutable ps_lits: lit list; (* TODO: thread it around instead? *)
(* proof state *)
ps_queue: (node*node) Vec.t;
@ -322,6 +337,10 @@ module Make(A: ARG) = struct
several times.
See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *)
and theory = (module THEORY with type cc = t)
type cc = t
let[@inline] size_ (r:repr) = r.n_size
let[@inline] true_ cc = Lazy.force cc.true_
let[@inline] false_ cc = Lazy.force cc.false_
@ -333,8 +352,10 @@ module Make(A: ARG) = struct
(* check if [t] is in the congruence closure.
Invariant: [in_cc t do_cc t => forall u subterm t, in_cc u] *)
let[@inline] mem (cc:t) (t:term): bool = T_tbl.mem cc.tbl t
(* FIXME
let on_merge cc f = cc.on_merge <- f :: cc.on_merge
let on_new_term cc f = cc.on_new_term <- f :: cc.on_new_term
*)
(* find representative, recursively *)
let[@unroll 2] rec find_rec (n:node) : repr =
@ -378,28 +399,6 @@ module Make(A: ARG) = struct
(Util.pp_seq ~sep:" " pp_n) (T_tbl.values cc.tbl)
(Util.pp_seq ~sep:" " pp_sig_e) (Sig_tbl.to_seq cc.signatures_tbl)
let th_data_get (_:t) (n:node) (key: _ Key.t) : _ option =
let n = find_ n in
K_map.find key n.n_th_data
(* update data for [n] *)
let th_data_add (type a) (self:t) (n:node) (key: (term,a) Key.t) (v:a) : unit =
let n = find_ n in
let map = n.n_th_data in
let old_v = K_map.find key map in
let v', is_diff = match old_v with
| None -> v, true
| Some old_v ->
let (module K) = key in
let v' = K.merge old_v v in
v', K.equal v v'
in
if is_diff then (
on_backtrack self (fun () -> n.n_th_data <- map);
);
n.n_th_data <- K_map.add key v' map;
()
(* compute up-to-date signature *)
let update_sig (s:signature) : Signature.t =
CC_types.map_view s
@ -506,34 +505,30 @@ module Make(A: ARG) = struct
(* TODO: turn this into a fold? *)
(* decompose explanation [e] of why [n1 = n2] *)
let decompose_explain cc (e:explanation) : unit =
let rec decompose_explain cc (e:explanation) : unit =
Log.debugf 5 (fun k->k "(@[cc.decompose_expl@ %a@])" Expl.pp e);
begin match e with
| E_reduction -> ()
| E_congruence (n1, n2) ->
begin match n1.n_sig0, n2.n_sig0 with
| Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) ->
assert (Fun.equal f1 f2);
assert (List.length a1 = List.length a2);
List.iter2 (ps_add_obligation cc) a1 a2;
| Some (App_ho (f1, a1)), Some (App_ho (f2, a2)) ->
assert (List.length a1 = List.length a2);
ps_add_obligation cc f1 f2;
List.iter2 (ps_add_obligation cc) a1 a2;
| Some (If (a1,b1,c1)), Some (If (a2,b2,c2)) ->
ps_add_obligation cc a1 a2;
ps_add_obligation cc b1 b2;
ps_add_obligation cc c1 c2;
| _ ->
assert false
end
| E_lit lit -> ps_add_lit cc lit
| E_lits l -> List.iter (ps_add_lit cc) l
| E_merge (a,b) -> ps_add_obligation cc a b
| E_merges l ->
(* need to explain each merge in [l] *)
List.iter (fun (t,u) -> ps_add_obligation cc t u) l
end
match e with
| E_reduction -> ()
| E_congruence (n1, n2) ->
begin match n1.n_sig0, n2.n_sig0 with
| Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) ->
assert (Fun.equal f1 f2);
assert (List.length a1 = List.length a2);
List.iter2 (ps_add_obligation cc) a1 a2;
| Some (App_ho (f1, a1)), Some (App_ho (f2, a2)) ->
assert (List.length a1 = List.length a2);
ps_add_obligation cc f1 f2;
List.iter2 (ps_add_obligation cc) a1 a2;
| Some (If (a1,b1,c1)), Some (If (a2,b2,c2)) ->
ps_add_obligation cc a1 a2;
ps_add_obligation cc b1 b2;
ps_add_obligation cc c1 c2;
| _ ->
assert false
end
| E_lit lit -> ps_add_lit cc lit
| E_merge (a,b) -> ps_add_obligation cc a b
| E_list l -> List.iter (decompose_explain cc) l
(* explain why [a = parent_a], where [a -> ... -> parent_a] in the
proof forest *)
@ -575,6 +570,7 @@ module Make(A: ARG) = struct
decompose_explain cc e;
explain_loop cc
(* FIXME remove
(* add [tag] to [n], indicating that [n] is distinct from all the other
nodes tagged with [tag]
precond: [n] is a representative *)
@ -585,6 +581,7 @@ module Make(A: ARG) = struct
(fun () -> n.n_tags <- Util.Int_map.remove tag n.n_tags);
n.n_tags <- Util.Int_map.add tag (n,expl) n.n_tags;
)
*)
(* add a term *)
let [@inline] rec add_term_rec_ cc t : node =
@ -611,7 +608,16 @@ module Make(A: ARG) = struct
(* [n] might be merged with other equiv classes *)
push_pending cc n;
);
List.iter (fun f -> f cc n t) cc.on_new_term;
(* initial theory data *)
let th_map =
IM.fold
(fun _ (module Th: THEORY with type cc=cc) th_map ->
match Th.on_new_term cc t with
| None -> th_map
| Some v -> K_map.add Th.key v th_map)
cc.theories K_map.empty
in
n.n_th_data <- th_map;
n
(* compute the initial signature of the given node *)
@ -701,7 +707,6 @@ module Make(A: ARG) = struct
(* TODO: remove, once we have moved distinct to a theory *)
and[@inline] task_combine_ cc acts = function
| CT_merge (a,b,e_ab) -> task_merge_ cc acts a b e_ab
| CT_distinct (l,tag,e) -> task_distinct_ cc acts l tag e
(* main CC algo: merge equivalence classes in [st.combine].
@raise Exn_unsat if merge fails *)
@ -731,26 +736,6 @@ module Make(A: ARG) = struct
else if size_ ra > size_ rb then rb, ra
else ra, rb
in
(* TODO: instead call micro theories, including "distinct" *)
(* update set of tags the new node cannot be equal to *)
let new_tags =
Util.Int_map.union
(fun _i (n1,e1) (n2,e2) ->
(* both maps contain same tag [_i]. conflict clause:
[e1 & e2 & e_ab] impossible *)
Log.debugf 5
(fun k->k "(@[<hv>cc.merge.distinct_conflict@ :tag %d@ \
@[:r1 %a@ :e1 %a@]@ @[:r2 %a@ :e2 %a@]@ :e_ab %a@])"
_i N.pp n1 Expl.pp e1
N.pp n2 Expl.pp e2 Expl.pp e_ab);
let lits = explain_unfold cc e1 in
let lits = explain_unfold ~init:lits cc e2 in
let lits = explain_unfold ~init:lits cc e_ab in
let lits = explain_eq_n ~init:lits cc a n1 in
let lits = explain_eq_n ~init:lits cc b n2 in
raise_conflict cc acts lits)
ra.n_tags rb.n_tags
in
(* when merging terms with [true] or [false], possibly propagate them to SAT *)
let merge_bool r1 t1 r2 t2 =
if N.equal r1 (true_ cc) then (
@ -763,6 +748,35 @@ module Make(A: ARG) = struct
merge_bool rb b ra a;
(* perform [union r_from r_into] *)
Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into);
(* call micro theories *)
begin
let th_into = r_into.n_th_data in
let th_from = r_from.n_th_data in
(* merge the two maps; if a key occurs in both, looks for theories with
this particular key *)
let th =
K_map.merge th_into th_from
~f_both:(fun id pair_into pair_from ->
match IM.find id cc.theories with
| (module Th : THEORY with type cc=t) ->
(* casting magic *)
let (module K) = Th.key in
begin match pair_into, pair_from with
| K.Store v_into, K.Store v_from ->
Log.debugf 15
(fun k->k "(@[cc.merge.th-on-merge@ :th %s@])" K.name);
(* FIXME: explanation is a=ra, e_ab, b=rb *)
Th.on_merge cc r_into v_into r_from v_from e_ab
| _ -> assert false
end
| exception Not_found -> ())
in
(* restore old data, if it changed *)
if not @@ K_map.equal th th_into then (
on_backtrack cc (fun () -> r_into.n_th_data <- th_into);
);
r_into.n_th_data <- th;
end;
begin
(* parents might have a different signature, check for collisions *)
N.iter_parents r_from
@ -773,7 +787,6 @@ module Make(A: ARG) = struct
assert (u.n_root == r_from);
u.n_root <- r_into);
(* now merge the classes *)
let r_into_old_tags = r_into.n_tags in
let r_into_old_next = r_into.n_next in
let r_from_old_next = r_from.n_next in
let r_into_old_parents = r_into.n_parents in
@ -786,11 +799,9 @@ module Make(A: ARG) = struct
N.pp r_from N.pp r_into);
r_into.n_next <- r_into_old_next;
r_from.n_next <- r_from_old_next;
r_into.n_tags <- r_into_old_tags;
r_into.n_parents <- r_into_old_parents;
N.iter_class_ r_from (fun u -> u.n_root <- r_from);
);
r_into.n_tags <- new_tags;
(* swap [into.next] and [from.next], merging the classes *)
r_into.n_next <- r_from_old_next;
r_from.n_next <- r_into_old_next;
@ -810,10 +821,9 @@ module Make(A: ARG) = struct
| _ -> assert false);
a.n_expl <- FL_some {next=b; expl=e_ab};
end;
(* notify listeners of the merge *)
List.iter (fun f -> f cc r_into r_from e_ab) cc.on_merge
)
(* FIXME: remove
and task_distinct_ cc acts (l:node list) tag expl : unit =
let l = List.map (fun n -> n, find_ n) l in
let coll =
@ -832,6 +842,7 @@ module Make(A: ARG) = struct
(* put a tag on all equivalence classes, that will make their merge fail *)
List.iter (fun (_,n) -> add_tag_n cc n tag expl) l
end
*)
(* we are merging [r1] with [r2==Bool(sign)], so propagate each term [u1]
in the equiv class of [r1] that is a known literal back to the SAT solver
@ -864,6 +875,61 @@ module Make(A: ARG) = struct
acts.Msat.acts_propagate lit reason
| _ -> ())
module Theory = struct
type cc = t
type t = theory
type 'a key = (term,lit,'a) Key.t
(* raise a conflict *)
let raise_conflict cc _n1 _n2 expl =
Log.debugf 5
(fun k->k "(@[cc.theory.raise-conflict@ :n1 %a@ :n2 %a@ :expl %a@])"
N.pp _n1 N.pp _n2 Expl.pp expl);
merge_classes cc (true_ cc) (false_ cc) expl
let merge cc n1 n2 expl =
Log.debugf 5
(fun k->k "(@[cc.theory.merge@ :n1 %a@ :n2 %a@ :expl %a@])" N.pp n1 N.pp n2 Expl.pp expl);
merge_classes cc n1 n2 expl
let add_term = add_term
let get_data _cc n key =
assert (N.is_root n);
K_map.find key n.n_th_data
(* FIXME: call micro theory here? in case of merge *)
(* update data for [n] *)
let add_data (type a) (self:cc) (n:node) (key: a key) (v:a) : unit =
let n = find_ n in
let map = n.n_th_data in
let old_v = K_map.find key map in
let v', is_diff = match old_v with
| None -> v, true
| Some old_v ->
let (module K) = key in
let v' = K.merge old_v v in
v', K.equal v v'
in
if is_diff then (
on_backtrack self (fun () -> n.n_th_data <- map);
);
n.n_th_data <- K_map.add key v' map;
()
let make (type a) ~(key:a key) ~on_merge ~on_new_term () : t =
let module Th = struct
type nonrec cc = cc
type data = a
let key = key
let key_id = Key.id key
let on_merge = on_merge
let on_new_term = on_new_term
end in
(module Th : THEORY with type cc=cc)
end
let check_invariants_ (cc:t) =
Log.debug 5 "(cc.check-invariants)";
Log.debugf 15 (fun k-> k "%a" pp_full cc);
@ -943,11 +1009,12 @@ module Make(A: ARG) = struct
Sequence.iter (assert_lit cc) lits
let assert_eq cc t1 t2 (e:lit list) : unit =
let expl = Expl.mk_lits e in
let expl = Expl.mk_list @@ List.rev_map Expl.mk_lit e in
let n1 = add_term cc t1 in
let n2 = add_term cc t2 in
merge_classes cc n1 n2 expl
(* FIXME: remove
(* generative tag used to annotate classes that can't be merged *)
let distinct_tag_ = ref 0
@ -958,14 +1025,23 @@ module Make(A: ARG) = struct
(fun k->k "(@[cc.assert_distinct@ (@[%a@])@ :tag %d@])" (Util.pp_list T.pp) l tag);
let l = List.map (add_term cc) l in
Vec.push cc.combine @@ CT_distinct (l, tag, Expl.mk_lit lit)
*)
let create ?(on_merge=[]) ?(on_new_term=[]) ?(size=`Big) (tst:term_state) : t =
let add_th (self:t) (th:theory) : unit =
let (module Th) = th in
if IM.mem Th.key_id self.theories then (
Error.errorf "attempt to add two theories with key %a" Key.pp Th.key
);
Log.debugf 3 (fun k->k "(@[@{<green>cc.add-theory@} %a@])" Key.pp Th.key);
self.theories <- IM.add Th.key_id th self.theories
let create ?th:(theories=[]) ?(size=`Big) (tst:term_state) : t =
let size = match size with `Small -> 128 | `Big -> 2048 in
let rec cc = {
tst;
tbl = T_tbl.create size;
signatures_tbl = Sig_tbl.create size;
on_merge; on_new_term;
theories=IM.empty;
pending=Vec.create();
combine=Vec.create();
ps_lits=[];
@ -981,6 +1057,7 @@ module Make(A: ARG) = struct
in
ignore (Lazy.force true_ : node);
ignore (Lazy.force false_ : node);
List.iter (add_th cc) theories; (* now add theories *)
cc
let[@inline] find_t cc t : repr =

View file

@ -2,11 +2,8 @@
module type ARG = Congruence_closure_intf.ARG
module type S = Congruence_closure_intf.S
module type THEORY_DATA = Congruence_closure_intf.THEORY_DATA
module type THEORY_KEY = Congruence_closure_intf.THEORY_KEY
type ('t, 'a) theory_data = ('t,'a) Congruence_closure_intf.theory_data
module Key : THEORY_KEY
module Make(A: ARG)

View file

@ -1,36 +1,35 @@
module type ARG = CC_types.FULL
(** Data stored by a theory for its own terms.
It needs to form a commutative monoid where values can be unmerged upon
backtracking.
*)
module type THEORY_DATA = sig
type term
type t
val empty : t
val equal : t -> t -> bool
(** Equality. This is used to optimize backtracking info. *)
val merge : t -> t -> t
(** [merge d1 d2] is called when merging classes with data [d1] and [d2]
respectively. The theory should already have checked that the merge
is compatible, and this produces the combined data for terms in the
merged class. *)
end
type ('t, 'a) theory_data = (module THEORY_DATA with type term = 't and type t = 'a)
module type THEORY_KEY = sig
type ('t, 'a) t
(** An access key for theories that use terms ['t] and which have
per-class data ['a] *)
type ('term,'lit,'a) t
(** An access key for theories which have per-class data ['a] *)
val create : ('t, 'a) theory_data -> ('t, 'a) t
(** Generative creation of keys for the given theory data. *)
val create :
?pp:'a Fmt.printer ->
name:string ->
eq:('a -> 'a -> bool) ->
merge:('a -> 'a -> 'a) ->
unit ->
('term,'lit,'a) t
(** Generative creation of keys for the given theory data.
@param eq : Equality. This is used to optimize backtracking info.
@param merge :
[merge d1 d2] is called when merging classes with data [d1] and [d2]
respectively. The theory should already have checked that the merge
is compatible, and this produces the combined data for terms in the
merged class.
@param name name of the theory which owns this data
@param pp a printer for the data
*)
val equal : ('t,'lit,_) t -> ('t,'lit,_) t -> bool
(** Checks if two keys are equal (generatively) *)
val pp : _ t Fmt.printer
(** Prints the name of the key. *)
end
module type S = sig
@ -87,12 +86,9 @@ module type S = sig
type t
val pp : t Fmt.printer
val mk_reduction : t
val mk_congruence : N.t -> N.t -> t
val mk_merge : N.t -> N.t -> t
val mk_merges : (N.t * N.t) list -> t
val mk_lit : lit -> t
val mk_lits : lit list -> t
val mk_list : t list -> t
end
type node = N.t
@ -119,34 +115,52 @@ module type S = sig
(** Actions available to the theory *)
type sat_actions = (Msat.void, lit, Msat.void, proof) Msat.acts
module Theory : sig
type cc = t
type t
type 'a key = (term,lit,'a) Key.t
val raise_conflict : cc -> Expl.t -> unit
(** Raise a conflict with the given explanation
it must be a theory tautology that [expl ==> absurd].
To be used in theories. *)
val merge : cc -> N.t -> N.t -> Expl.t -> unit
(** Merge these two nodes given this explanation.
It must be a theory tautology that [expl ==> n1 = n2].
To be used in theories. *)
val add_term : cc -> term -> N.t
(** Add/retrieve node for this term.
To be used in theories *)
val get_data : cc -> N.t -> 'a key -> 'a option
(** Get data information for this particular representative *)
val add_data : cc -> N.t -> 'a key -> 'a -> unit
(** Add data to this particular representative. Will be backtracked. *)
val make :
key:'a key ->
on_merge:(cc -> N.t -> 'a -> N.t -> 'a -> Expl.t -> unit) ->
on_new_term:(cc -> term -> 'a option) ->
unit ->
t
(** Build a micro theory. It can use the callbacks above. *)
end
val create :
?on_merge:(t -> repr -> repr -> explanation -> unit) list ->
?on_new_term:(t -> repr -> term -> unit) list ->
?th:Theory.t list ->
?size:[`Small | `Big] ->
term_state ->
t
(** Create a new congruence closure. *)
val on_merge : t -> (t -> repr -> repr -> explanation -> unit) -> unit
(** Add a callback, to be called whenever two classes are merged *)
val on_new_term : t -> (t -> repr -> term -> unit) -> unit
(** Add a callback, to be called whenever a node is added *)
val merge_classes : t -> node -> node -> explanation -> unit
(** Merge the two given nodes with given explanation.
It must be a theory tautology that [expl ==> n1 = n2] *)
val th_data_get : t -> N.t -> (term, 'a) Key.t -> 'a option
(** Get data information for this particular representative *)
val th_data_add : t -> N.t -> (term, 'a) Key.t -> 'a -> unit
(** Add the given data to this node (or rather, to its representative).
This will be backtracked. *)
(* TODO: merge true/false?
val raise_conflict : CC.t -> CC.N.t -> CC.N.t -> Expl.t -> 'a
*)
val add_th : t -> Theory.t -> unit
(** Add a (micro) theory to the congruence closure.
@raise Error.Error if there is already a theory with
the same key. *)
val set_as_lit : t -> N.t -> lit -> unit
(** map the given node to a literal. *)
@ -173,11 +187,12 @@ module type S = sig
val assert_eq : t -> term -> term -> lit list -> unit
(** merge the given terms with some explanations *)
(* TODO: remove and move into its own library as a micro theory *)
(* TODO: remove and move into its own library as a micro theory
val assert_distinct : t -> term list -> neq:term -> lit -> unit
(** [assert_distinct l ~neq:u e] asserts all elements of [l] are distinct
because [lit] is true
precond: [u = distinct l] *)
*)
val check : t -> sat_actions -> unit
(** Perform all pending operations done via {!assert_eq}, {!assert_lit}, etc.

View file

@ -16,6 +16,7 @@ module type S = Congruence_closure.S
module Mini_cc = Mini_cc
module Congruence_closure = Congruence_closure
module Key = Congruence_closure.Key
module Make = Congruence_closure.Make

View file

@ -150,8 +150,8 @@ let eval (m:t) (t:Term.t) : Value.t option =
let b = aux b in
if Value.equal a b then Value.true_ else Value.false_
| App_cst (c, args) ->
begin try Term.Map.find t m.values
with Not_found ->
try Term.Map.find t m.values
with Not_found ->
match Cst.view c with
| Cst_def udef ->
(* use builtin interpretation function *)
@ -168,7 +168,6 @@ let eval (m:t) (t:Term.t) : Value.t option =
| exception Not_found ->
raise No_value (* no particular interpretation *)
end
end
in
try Some (aux t)
with No_value -> None

View file

@ -13,9 +13,12 @@ module Lit = Lit
module Theory_combine = Theory_combine
module Theory = Theory
module Solver = Solver
module CC = CC
module Solver_types = Solver_types
type theory = Theory.t
(**/**)
module Vec = Msat.Vec
module Log = Msat.Log

View file

@ -175,9 +175,10 @@ let assume (self:t) (c:Lit.t IArray.t) : unit =
let c = IArray.to_array_map (Sat_solver.make_atom sat) c in
Sat_solver.add_clause_a sat c Proof_default
(* TODO: remove? use a special constant + micro theory instead? *)
(* TODO: remove? use a special constant + micro theory instead?
let[@inline] assume_distinct self l ~neq lit : unit =
CC.assert_distinct (cc self) l lit ~neq
*)
let check_model (_s:t) : unit =
Log.debug 1 "(smt.solver.check-model)";

View file

@ -55,7 +55,9 @@ val mk_atom_t : t -> ?sign:bool -> Term.t -> Atom.t
val assume : t -> Lit.t IArray.t -> unit
(* TODO: use the theory instead
val assume_distinct : t -> Term.t list -> neq:Term.t -> Lit.t -> unit
*)
val solve :
?on_exit:(unit -> unit) list ->

View file

@ -23,16 +23,11 @@ module CC_expl = CC.Expl
(** Actions available to a theory during its lifetime *)
module type ACTIONS = sig
val cc : CC.t
val raise_conflict: conflict -> 'a
(** Give a conflict clause to the solver *)
val propagate_eq: Term.t -> Term.t -> Lit.t list -> unit
(** Propagate an equality [t = u] because [e].
TODO: use [CC.Expl] instead, with lit/merge constructors *)
val propagate_distinct: Term.t list -> neq:Term.t -> Lit.t -> unit
(** Propagate a [distinct l] because [e] (where [e = neq] *)
val propagate: Lit.t -> (unit -> Lit.t list) -> unit
(** Propagate a boolean using a unit clause.
[expl => lit] must be a theory lemma, that is, a T-tautology *)
@ -48,16 +43,6 @@ module type ACTIONS = sig
val add_persistent_axiom: Lit.t list -> unit
(** Add toplevel clause to the SAT solver. This clause will
not be backtracked. *)
val cc_add_term: Term.t -> CC_eq_class.t
(** add/get term to the congruence closure *)
val cc_find: CC_eq_class.t -> CC_eq_class.t
(** Find representative of this in the congruence closure *)
val cc_all_classes: CC_eq_class.t Sequence.t
(** All current equivalence classes
(caution: linear in the number of terms existing in the congruence closure) *)
end
type actions = (module ACTIONS)

View file

@ -51,9 +51,8 @@ let assert_lits_ ~final (self:t) acts (lits:Lit.t Sequence.t) : unit =
(* transmit to theories. *)
CC.check cc acts;
let module A = struct
let cc = cc
let[@inline] raise_conflict c : 'a = acts.Msat.acts_raise_conflict c Proof_default
let[@inline] propagate_eq t u expl : unit = CC.assert_eq cc t u expl
let propagate_distinct ts ~neq expl = CC.assert_distinct cc ts ~neq expl
let[@inline] propagate p cs : unit =
acts.Msat.acts_propagate p (Msat.Consequence (fun () -> cs(), Proof_default))
let[@inline] propagate_l p cs : unit = propagate p (fun()->cs)
@ -61,9 +60,6 @@ let assert_lits_ ~final (self:t) acts (lits:Lit.t Sequence.t) : unit =
acts.Msat.acts_add_clause ~keep:false lits Proof_default
let[@inline] add_persistent_axiom lits : unit =
acts.Msat.acts_add_clause ~keep:true lits Proof_default
let[@inline] cc_add_term t = CC.add_term cc t
let[@inline] cc_find t = CC.find cc t
let cc_all_classes = CC.all_classes cc
end in
let acts = (module A : Theory.ACTIONS) in
theories self

View file

@ -8,6 +8,7 @@ type 'a or_error = ('a, string) CCResult.t
module E = CCResult
module A = Ast
module Form = Sidekick_th_bool.Bool_term
module Distinct = Sidekick_th_distinct
module Fmt = CCFormat
module Dot = Msat_backend.Dot.Make(Solver.Sat_solver)(Msat_backend.Dot.Default(Solver.Sat_solver))
@ -137,7 +138,7 @@ module Conv = struct
in
Form.and_l tst (curry_eq l)
| A.Op (A.Distinct, l) ->
Form.distinct_l tst @@ List.map (aux subst) l
Distinct.distinct_l tst @@ List.map (aux subst) l
| A.Not f -> Form.not_ tst (aux subst f)
| A.Bool true -> Term.true_ tst
| A.Bool false -> Term.false_ tst

View file

@ -3,12 +3,8 @@
(name sidekick_smtlib)
(public_name sidekick.smtlib)
(libraries containers zarith msat sidekick.smt sidekick.util
sidekick.smt.th-bool msat.backend)
(flags :standard -w +a-4-42-44-48-50-58-32-60@8
-safe-string -color always -open Sidekick_util)
(ocamlopt_flags :standard -O3 -color always -bin-annot
-unbox-closures -unbox-closures-factor 20)
)
sidekick.smt.th-bool sidekick.smt.th-distinct msat.backend)
(flags :standard -open Sidekick_util))
(menhir (modules Parser))

View file

@ -3,11 +3,9 @@
type 'a view =
| B_not of 'a
| B_eq of 'a * 'a
| B_and of 'a IArray.t
| B_or of 'a IArray.t
| B_imply of 'a IArray.t * 'a
| B_distinct of 'a IArray.t
| B_atom of 'a
(** {2 Interface for a representation of boolean terms} *)

View file

@ -18,7 +18,6 @@ let id_not = ID.make "not"
let id_and = ID.make "and"
let id_or = ID.make "or"
let id_imply = ID.make "=>"
let id_distinct = ID.make "distinct"
let equal = T.equal
let hash = T.hash
@ -34,17 +33,14 @@ let view_id cst_id args =
(* conclusion is stored first *)
let len = IArray.length args in
B_imply (IArray.sub args 1 (len-1), IArray.get args 0)
) else if ID.equal cst_id id_distinct then (
B_distinct args
) else (
raise_notrace Not_a_th_term
)
let view_as_bool (t:T.t) : T.t view =
match T.view t with
| Eq (a,b) -> B_eq (a,b)
| App_cst ({cst_id; _}, args) ->
begin try view_id cst_id args with Not_a_th_term -> B_atom t end
(try view_id cst_id args with Not_a_th_term -> B_atom t)
| _ -> B_atom t
module C = struct
@ -69,14 +65,7 @@ module C = struct
| B_imply (_, V_bool true) -> Value.true_
| B_imply (a,_) when IArray.exists Value.is_false a -> Value.true_
| B_imply (a,b) when IArray.for_all Value.is_bool a && Value.is_bool b -> Value.false_
| B_eq (a,b) -> Value.bool @@ Value.equal a b
| B_atom v -> v
| B_distinct a ->
if
Sequence.diagonal (IArray.to_seq a)
|> Sequence.for_all (fun (x,y) -> not @@ Value.equal x y)
then Value.true_
else Value.false_
| B_not _ | B_and _ | B_or _ | B_imply _
-> Error.errorf "non boolean value in boolean connective"
@ -92,7 +81,6 @@ module C = struct
let and_ = mk_cst id_and
let or_ = mk_cst id_or
let imply = mk_cst id_imply
let distinct = mk_cst id_distinct
end
let as_id id (t:T.t) : T.t IArray.t option =
@ -152,20 +140,9 @@ let imply_l st xs y = match xs with
let imply st a b = imply_a st (IArray.singleton a) b
let distinct st a =
if IArray.length a <= 1
then T.true_ st
else T.app_cst st C.distinct a
let distinct_l st = function
| [] | [_] -> T.true_ st
| xs -> distinct st (IArray.of_list xs)
let make st = function
| B_atom t -> t
| B_eq (a,b) -> T.eq st a b
| B_and l -> and_a st l
| B_or l -> or_a st l
| B_imply (a,b) -> imply_a st a b
| B_not t -> not_ st t
| B_distinct l -> distinct st l

View file

@ -15,8 +15,6 @@ val imply_a : state -> term IArray.t -> term -> term
val imply_l : state -> term list -> term -> term
val eq : state -> term -> term -> term
val neq : state -> term -> term -> term
val distinct : state -> term IArray.t -> term
val distinct_l : state -> term list -> term
val and_a : state -> term IArray.t -> term
val and_l : state -> term list -> term
val or_a : state -> term IArray.t -> term

View file

@ -9,11 +9,9 @@ module Th_dyn_tseitin = Th_dyn_tseitin
type 'a view = 'a Intf.view =
| B_not of 'a
| B_eq of 'a * 'a
| B_and of 'a IArray.t
| B_or of 'a IArray.t
| B_imply of 'a IArray.t * 'a
| B_distinct of 'a IArray.t
| B_atom of 'a
module type BOOL_TERM = Intf.BOOL_TERM

View file

@ -15,14 +15,7 @@ module Make(Term : ARG) = struct
type term = Term.t
module T_tbl = CCHashtbl.Make(Term)
module Lit = struct
include Sidekick_smt.Lit
let eq tst a b = atom tst ~sign:true @@ Term.make tst (B_eq (a,b))
let neq tst a b = neg @@ eq tst a b
end
let pp_c out c = Fmt.fprintf out "(@[<hv>%a@])" (Util.pp_list Lit.pp) c
module Lit = Sidekick_smt.Lit
type t = {
tst: Term.state;
@ -39,22 +32,7 @@ module Make(Term : ARG) = struct
in
match v with
| B_not _ -> assert false (* normalized *)
| B_atom _ | B_eq _ -> () (* CC will manage *)
| B_distinct l ->
let l = IArray.to_list l in
if Lit.sign lit then (
A.propagate_distinct l ~neq:lit_t lit
) else if final && not @@ expanded () then (
(* add clause [distinct t1…tn _{i,j>i} t_i=j] *)
let c =
Sequence.diagonal_l l
|> Sequence.map (fun (t,u) -> Lit.eq self.tst t u)
|> Sequence.to_rev_list
in
let c = Lit.neg lit :: c in
Log.debugf 5 (fun k->k "(@[tseitin.distinct.case-split@ %a@])" pp_c c);
add_axiom c
)
| B_atom _ -> () (* CC will manage *)
| B_and subs ->
if Lit.sign lit then (
(* propagate [lit => subs_i] *)
@ -105,7 +83,7 @@ module Make(Term : ARG) = struct
(fun lit ->
let t = Lit.term lit in
match Term.view_as_bool t with
| B_atom _ | B_eq _ -> ()
| B_atom _ -> ()
| v -> tseitin ~final self acts lit t v)
let partial_check (self:t) acts (lits:Lit.t Sequence.t) =

View file

@ -12,11 +12,5 @@ module type ARG = Bool_intf.BOOL_TERM
module Make(Term : ARG) : sig
type term = Term.t
module Lit : sig
type t = Sidekick_smt.Lit.t
val eq : Term.state -> term -> term -> t
val neq : Term.state -> term -> term -> t
end
val th : Sidekick_smt.Theory.t
end

View file

@ -2,8 +2,5 @@
(name Sidekick_th_bool)
(public_name sidekick.smt.th-bool)
(libraries containers sidekick.smt)
(flags :standard -w +a-4-44-48-58-60@8
-color always -safe-string -short-paths -open Sidekick_util)
(ocamlopt_flags :standard -O3 -color always
-unbox-closures -unbox-closures-factor 20))
(flags :standard -open Sidekick_util))