wip: new micro-theories in CC

This commit is contained in:
Simon Cruanes 2019-02-26 22:46:40 -06:00
parent 57147cea85
commit 342dba4533
19 changed files with 307 additions and 294 deletions

View file

@ -3,41 +3,62 @@ open CC_types
module type ARG = Congruence_closure_intf.ARG module type ARG = Congruence_closure_intf.ARG
module type S = Congruence_closure_intf.S module type S = Congruence_closure_intf.S
module type THEORY_DATA = Congruence_closure_intf.THEORY_DATA
module type THEORY_KEY = Congruence_closure_intf.THEORY_KEY module type THEORY_KEY = Congruence_closure_intf.THEORY_KEY
type ('t, 'a) theory_data = ('t,'a) Congruence_closure_intf.theory_data
module type KEY_IMPL = sig
include THEORY_DATA
exception Store of t
val id : int
end
(** Custom keys for theory data. (** Custom keys for theory data.
This imitates the classic tricks for heterogeneous maps This imitates the classic tricks for heterogeneous maps
https://blog.janestreet.com/a-universal-type/ https://blog.janestreet.com/a-universal-type/
*)
It needs to form a commutative monoid where values are persistent so
they can be restored during backtracking.
*)
module Key = struct module Key = struct
type ('term, 'a) t = (module KEY_IMPL with type term = 'term and type t = 'a) module type KEY_IMPL = sig
type term
type lit
type t
val id : int
val name : string
val pp : t Fmt.printer
val equal : t -> t -> bool
val merge : t -> t -> t
exception Store of t
end
type ('term,'lit,'a) t =
(module KEY_IMPL with type term = 'term and type lit = 'lit and type t = 'a)
let n_ = ref 0 let n_ = ref 0
let create (type term)(type d) (th:(term,d) theory_data) : (term,d) t = let create (type term)(type lit)(type d)
let (module TH) = th in ?(pp=fun out _ -> Fmt.string out "<opaque>")
~name ~eq ~merge () : (term,lit,d) t =
let module K = struct let module K = struct
include TH type nonrec term = term
exception Store of d type nonrec lit = lit
type t = d
let id = !n_ let id = !n_
let name = name
let pp = pp
let merge = merge
let equal = eq
exception Store of d
end in end in
incr n_; incr n_;
(module K) (module K)
let id (module K : KEY_IMPL) = K.id let[@inline] id
: type term lit a. (term,lit,a) t -> int
= fun (module K) -> K.id
let equal let[@inline] equal
: type a b term. (term,a) t -> (term,b) t -> bool : type term lit a b. (term,lit,a) t -> (term,lit,b) t -> bool
= fun (module K1) (module K2) -> K1.id = K2.id = fun (module K1) (module K2) -> K1.id = K2.id
let pp
: type term lit a. (term,lit,a) t Fmt.printer
= fun out (module K) -> Fmt.string out K.name
end end
module Bits = CCBitField.Make() module Bits = CCBitField.Make()
@ -67,12 +88,12 @@ module Make(A: ARG) = struct
module T = A.Term module T = A.Term
module Fun = A.Fun module Fun = A.Fun
module Key = Key module Key = Key
module IM = Map.Make(CCInt)
(** Map for theory data associated with representatives *) (** Map for theory data associated with representatives *)
module K_map = struct module K_map = struct
type pair = Pair : (term, 'a) Key.t * exn -> pair type 'a key = (term,lit,'a) Key.t
module IM = Map.Make(CCInt) type pair = Pair : 'a key * exn -> pair
type t = pair IM.t type t = pair IM.t
@ -80,20 +101,18 @@ module Make(A: ARG) = struct
let[@inline] mem k t = IM.mem (Key.id k) t let[@inline] mem k t = IM.mem (Key.id k) t
let is_empty = IM.is_empty let find (type a) (k : a key) (self:t) : a option =
let find (type a) (k : (term,a) Key.t) (self:t) : a option =
let (module K) = k in let (module K) = k in
match IM.find K.id self with match IM.find K.id self with
| Pair (_, K.Store v) -> Some v | Pair (_, K.Store v) -> Some v
| _ -> None | _ -> None
| exception Not_found -> None | exception Not_found -> None
let add (type a) (k : (term,a) Key.t) (v:a) (self:t) : t = let add (type a) (k : a key) (v:a) (self:t) : t =
let (module K) = k in let (module K) = k in
IM.add K.id (Pair (k, K.Store v)) self IM.add K.id (Pair (k, K.Store v)) self
let remove (type a) (k: (term,a) Key.t) self : t = let remove (type a) (k: a key) self : t =
let (module K) = k in let (module K) = k in
IM.remove K.id self IM.remove K.id self
@ -102,22 +121,23 @@ module Make(A: ARG) = struct
(fun p1 p2 -> (fun p1 p2 ->
let Pair ((module K1), v1) = p1 in let Pair ((module K1), v1) = p1 in
let Pair ((module K2), v2) = p2 in let Pair ((module K2), v2) = p2 in
K1.id = K2.id && assert (K1.id = K2.id);
match v1, v2 with K1.Store v1, K1.Store v2 -> K1.equal v1 v2 | _ -> false) match v1, v2 with K1.Store v1, K1.Store v2 -> K1.equal v1 v2 | _ -> false)
m1 m2 m1 m2
let merge (m1:t) (m2:t) : t = let merge ~f_both (m1:t) (m2:t) : t =
IM.merge IM.merge
(fun _ p1 p2 -> (fun _ p1 p2 ->
match p1, p2 with match p1, p2 with
| None, None -> None | None, None -> None
| Some v, None | Some v, None
| None, Some v -> Some v | None, Some v -> Some v
| Some (Pair ((module K1) as key1, v1)), Some (Pair (_, v2)) -> | Some (Pair ((module K1) as key1, pair1)), Some (Pair (_, pair2)) ->
match v1, v2 with match pair1, pair2 with
| K1.Store v1, K1.Store v2 -> | K1.Store v1, K1.Store v2 ->
(* merge content *) f_both K1.id pair1 pair2; (* callback for checking compat *)
Some (Pair (key1, K1.Store (K1.merge v1 v2))) let v12 = K1.merge v1 v2 in (* merge content *)
Some (Pair (key1, K1.Store v12))
| _ -> assert false | _ -> assert false
) )
m1 m2 m1 m2
@ -137,9 +157,6 @@ module Make(A: ARG) = struct
mutable n_as_lit: lit option; (* TODO: put into payload? and only in root? *) mutable n_as_lit: lit option; (* TODO: put into payload? and only in root? *)
mutable n_expl: explanation_forest_link; (* the rooted forest for explanations *) mutable n_expl: explanation_forest_link; (* the rooted forest for explanations *)
mutable n_th_data: K_map.t; (* theory data *) mutable n_th_data: K_map.t; (* theory data *)
(* TODO: make a micro theory and move this inside *)
mutable n_tags: (node * explanation) Util.Int_map.t;
(* "distinct" tags (i.e. set of `(distinct t1…tn)` terms this belongs to *)
} }
and signature = (fun_, node, node list) view and signature = (fun_, node, node list) view
@ -151,15 +168,13 @@ module Make(A: ARG) = struct
expl: explanation; expl: explanation;
} }
(* TODO: make this recursive (the list case) *)
(* atomic explanation in the congruence closure *) (* atomic explanation in the congruence closure *)
and explanation = and explanation =
| E_reduction (* by pure reduction, tautologically equal *) | E_reduction (* by pure reduction, tautologically equal *)
| E_merge of node * node
| E_merges of (node * node) list (* caused by these merges *)
| E_congruence of node * node (* caused by normal congruence *)
| E_lit of lit (* because of this literal *) | E_lit of lit (* because of this literal *)
| E_lits of lit list (* because of this (true) conjunction *) | E_merge of node * node
| E_list of explanation list
| E_congruence of node * node (* caused by normal congruence *)
type repr = node type repr = node
type conflict = lit list type conflict = lit list
@ -185,7 +200,6 @@ module Make(A: ARG) = struct
n_next=n; n_next=n;
n_size=1; n_size=1;
n_th_data=K_map.empty; n_th_data=K_map.empty;
n_tags=Util.Int_map.empty;
} in } in
n n
@ -217,30 +231,24 @@ module Make(A: ARG) = struct
module Expl = struct module Expl = struct
type t = explanation type t = explanation
let pp out (e:explanation) = match e with let rec pp out (e:explanation) = match e with
| E_reduction -> Fmt.string out "reduction" | E_reduction -> Fmt.string out "reduction"
| E_lit lit -> A.Lit.pp out lit | E_lit lit -> A.Lit.pp out lit
| E_congruence (n1,n2) -> Fmt.fprintf out "(@[congruence@ %a@ %a@])" N.pp n1 N.pp n2 | E_congruence (n1,n2) -> Fmt.fprintf out "(@[congruence@ %a@ %a@])" N.pp n1 N.pp n2
| E_lits l -> CCFormat.Dump.list A.Lit.pp out l
| E_merge (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" N.pp a N.pp b | E_merge (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" N.pp a N.pp b
| E_merges l -> | E_list l ->
Format.fprintf out "(@[<hv1>merges@ %a@])" Format.fprintf out "(@[<hv1>and@ %a@])"
Fmt.(seq ~sep:(return "@ ") @@ within "[" "]" @@ hvbox @@ Fmt.(list ~sep:(return "@ ") @@ within "[" "]" @@ hvbox @@ pp) l
pair ~sep:(return " ~@ ") N.pp N.pp)
(Sequence.of_list l)
let mk_reduction : t = E_reduction let mk_reduction : t = E_reduction
let[@inline] mk_congruence n1 n2 : t = E_congruence (n1,n2) let[@inline] mk_congruence n1 n2 : t = E_congruence (n1,n2)
let[@inline] mk_merge a b : t = E_merge (a,b) let[@inline] mk_merge a b : t = E_merge (a,b)
let[@inline] mk_merges = function
| [] -> mk_reduction
| [(a,b)] -> mk_merge a b
| l -> E_merges l
let[@inline] mk_lit l : t = E_lit l let[@inline] mk_lit l : t = E_lit l
let[@inline] mk_lits = function let mk_list l =
match l with
| [] -> mk_reduction | [] -> mk_reduction
| [x] -> mk_lit x | [x] -> x
| l -> E_lits l | l -> E_list l
end end
(** A signature is a shallow term shape where immediate subterms (** A signature is a shallow term shape where immediate subterms
@ -290,7 +298,15 @@ module Make(A: ARG) = struct
type combine_task = type combine_task =
| CT_merge of node * node * explanation | CT_merge of node * node * explanation
| CT_distinct of node list * int * explanation
module type THEORY = sig
type cc
type data
val key_id : int
val key : (term,lit,data) Key.t
val on_merge : cc -> N.t -> data -> N.t -> data -> Expl.t -> unit
val on_new_term: cc -> term -> data option
end
type t = { type t = {
tst: term_state; tst: term_state;
@ -307,8 +323,7 @@ module Make(A: ARG) = struct
pending: node Vec.t; pending: node Vec.t;
combine: combine_task Vec.t; combine: combine_task Vec.t;
undo: (unit -> unit) Backtrack_stack.t; undo: (unit -> unit) Backtrack_stack.t;
mutable on_merge: (t -> repr -> repr -> explanation -> unit) list; mutable theories: theory IM.t;
mutable on_new_term: (t -> repr -> term -> unit) list;
mutable ps_lits: lit list; (* TODO: thread it around instead? *) mutable ps_lits: lit list; (* TODO: thread it around instead? *)
(* proof state *) (* proof state *)
ps_queue: (node*node) Vec.t; ps_queue: (node*node) Vec.t;
@ -322,6 +337,10 @@ module Make(A: ARG) = struct
several times. several times.
See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *) See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *)
and theory = (module THEORY with type cc = t)
type cc = t
let[@inline] size_ (r:repr) = r.n_size let[@inline] size_ (r:repr) = r.n_size
let[@inline] true_ cc = Lazy.force cc.true_ let[@inline] true_ cc = Lazy.force cc.true_
let[@inline] false_ cc = Lazy.force cc.false_ let[@inline] false_ cc = Lazy.force cc.false_
@ -333,8 +352,10 @@ module Make(A: ARG) = struct
(* check if [t] is in the congruence closure. (* check if [t] is in the congruence closure.
Invariant: [in_cc t do_cc t => forall u subterm t, in_cc u] *) Invariant: [in_cc t do_cc t => forall u subterm t, in_cc u] *)
let[@inline] mem (cc:t) (t:term): bool = T_tbl.mem cc.tbl t let[@inline] mem (cc:t) (t:term): bool = T_tbl.mem cc.tbl t
(* FIXME
let on_merge cc f = cc.on_merge <- f :: cc.on_merge let on_merge cc f = cc.on_merge <- f :: cc.on_merge
let on_new_term cc f = cc.on_new_term <- f :: cc.on_new_term let on_new_term cc f = cc.on_new_term <- f :: cc.on_new_term
*)
(* find representative, recursively *) (* find representative, recursively *)
let[@unroll 2] rec find_rec (n:node) : repr = let[@unroll 2] rec find_rec (n:node) : repr =
@ -378,28 +399,6 @@ module Make(A: ARG) = struct
(Util.pp_seq ~sep:" " pp_n) (T_tbl.values cc.tbl) (Util.pp_seq ~sep:" " pp_n) (T_tbl.values cc.tbl)
(Util.pp_seq ~sep:" " pp_sig_e) (Sig_tbl.to_seq cc.signatures_tbl) (Util.pp_seq ~sep:" " pp_sig_e) (Sig_tbl.to_seq cc.signatures_tbl)
let th_data_get (_:t) (n:node) (key: _ Key.t) : _ option =
let n = find_ n in
K_map.find key n.n_th_data
(* update data for [n] *)
let th_data_add (type a) (self:t) (n:node) (key: (term,a) Key.t) (v:a) : unit =
let n = find_ n in
let map = n.n_th_data in
let old_v = K_map.find key map in
let v', is_diff = match old_v with
| None -> v, true
| Some old_v ->
let (module K) = key in
let v' = K.merge old_v v in
v', K.equal v v'
in
if is_diff then (
on_backtrack self (fun () -> n.n_th_data <- map);
);
n.n_th_data <- K_map.add key v' map;
()
(* compute up-to-date signature *) (* compute up-to-date signature *)
let update_sig (s:signature) : Signature.t = let update_sig (s:signature) : Signature.t =
CC_types.map_view s CC_types.map_view s
@ -506,9 +505,9 @@ module Make(A: ARG) = struct
(* TODO: turn this into a fold? *) (* TODO: turn this into a fold? *)
(* decompose explanation [e] of why [n1 = n2] *) (* decompose explanation [e] of why [n1 = n2] *)
let decompose_explain cc (e:explanation) : unit = let rec decompose_explain cc (e:explanation) : unit =
Log.debugf 5 (fun k->k "(@[cc.decompose_expl@ %a@])" Expl.pp e); Log.debugf 5 (fun k->k "(@[cc.decompose_expl@ %a@])" Expl.pp e);
begin match e with match e with
| E_reduction -> () | E_reduction -> ()
| E_congruence (n1, n2) -> | E_congruence (n1, n2) ->
begin match n1.n_sig0, n2.n_sig0 with begin match n1.n_sig0, n2.n_sig0 with
@ -528,12 +527,8 @@ module Make(A: ARG) = struct
assert false assert false
end end
| E_lit lit -> ps_add_lit cc lit | E_lit lit -> ps_add_lit cc lit
| E_lits l -> List.iter (ps_add_lit cc) l
| E_merge (a,b) -> ps_add_obligation cc a b | E_merge (a,b) -> ps_add_obligation cc a b
| E_merges l -> | E_list l -> List.iter (decompose_explain cc) l
(* need to explain each merge in [l] *)
List.iter (fun (t,u) -> ps_add_obligation cc t u) l
end
(* explain why [a = parent_a], where [a -> ... -> parent_a] in the (* explain why [a = parent_a], where [a -> ... -> parent_a] in the
proof forest *) proof forest *)
@ -575,6 +570,7 @@ module Make(A: ARG) = struct
decompose_explain cc e; decompose_explain cc e;
explain_loop cc explain_loop cc
(* FIXME remove
(* add [tag] to [n], indicating that [n] is distinct from all the other (* add [tag] to [n], indicating that [n] is distinct from all the other
nodes tagged with [tag] nodes tagged with [tag]
precond: [n] is a representative *) precond: [n] is a representative *)
@ -585,6 +581,7 @@ module Make(A: ARG) = struct
(fun () -> n.n_tags <- Util.Int_map.remove tag n.n_tags); (fun () -> n.n_tags <- Util.Int_map.remove tag n.n_tags);
n.n_tags <- Util.Int_map.add tag (n,expl) n.n_tags; n.n_tags <- Util.Int_map.add tag (n,expl) n.n_tags;
) )
*)
(* add a term *) (* add a term *)
let [@inline] rec add_term_rec_ cc t : node = let [@inline] rec add_term_rec_ cc t : node =
@ -611,7 +608,16 @@ module Make(A: ARG) = struct
(* [n] might be merged with other equiv classes *) (* [n] might be merged with other equiv classes *)
push_pending cc n; push_pending cc n;
); );
List.iter (fun f -> f cc n t) cc.on_new_term; (* initial theory data *)
let th_map =
IM.fold
(fun _ (module Th: THEORY with type cc=cc) th_map ->
match Th.on_new_term cc t with
| None -> th_map
| Some v -> K_map.add Th.key v th_map)
cc.theories K_map.empty
in
n.n_th_data <- th_map;
n n
(* compute the initial signature of the given node *) (* compute the initial signature of the given node *)
@ -701,7 +707,6 @@ module Make(A: ARG) = struct
(* TODO: remove, once we have moved distinct to a theory *) (* TODO: remove, once we have moved distinct to a theory *)
and[@inline] task_combine_ cc acts = function and[@inline] task_combine_ cc acts = function
| CT_merge (a,b,e_ab) -> task_merge_ cc acts a b e_ab | CT_merge (a,b,e_ab) -> task_merge_ cc acts a b e_ab
| CT_distinct (l,tag,e) -> task_distinct_ cc acts l tag e
(* main CC algo: merge equivalence classes in [st.combine]. (* main CC algo: merge equivalence classes in [st.combine].
@raise Exn_unsat if merge fails *) @raise Exn_unsat if merge fails *)
@ -731,26 +736,6 @@ module Make(A: ARG) = struct
else if size_ ra > size_ rb then rb, ra else if size_ ra > size_ rb then rb, ra
else ra, rb else ra, rb
in in
(* TODO: instead call micro theories, including "distinct" *)
(* update set of tags the new node cannot be equal to *)
let new_tags =
Util.Int_map.union
(fun _i (n1,e1) (n2,e2) ->
(* both maps contain same tag [_i]. conflict clause:
[e1 & e2 & e_ab] impossible *)
Log.debugf 5
(fun k->k "(@[<hv>cc.merge.distinct_conflict@ :tag %d@ \
@[:r1 %a@ :e1 %a@]@ @[:r2 %a@ :e2 %a@]@ :e_ab %a@])"
_i N.pp n1 Expl.pp e1
N.pp n2 Expl.pp e2 Expl.pp e_ab);
let lits = explain_unfold cc e1 in
let lits = explain_unfold ~init:lits cc e2 in
let lits = explain_unfold ~init:lits cc e_ab in
let lits = explain_eq_n ~init:lits cc a n1 in
let lits = explain_eq_n ~init:lits cc b n2 in
raise_conflict cc acts lits)
ra.n_tags rb.n_tags
in
(* when merging terms with [true] or [false], possibly propagate them to SAT *) (* when merging terms with [true] or [false], possibly propagate them to SAT *)
let merge_bool r1 t1 r2 t2 = let merge_bool r1 t1 r2 t2 =
if N.equal r1 (true_ cc) then ( if N.equal r1 (true_ cc) then (
@ -763,6 +748,35 @@ module Make(A: ARG) = struct
merge_bool rb b ra a; merge_bool rb b ra a;
(* perform [union r_from r_into] *) (* perform [union r_from r_into] *)
Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into); Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into);
(* call micro theories *)
begin
let th_into = r_into.n_th_data in
let th_from = r_from.n_th_data in
(* merge the two maps; if a key occurs in both, looks for theories with
this particular key *)
let th =
K_map.merge th_into th_from
~f_both:(fun id pair_into pair_from ->
match IM.find id cc.theories with
| (module Th : THEORY with type cc=t) ->
(* casting magic *)
let (module K) = Th.key in
begin match pair_into, pair_from with
| K.Store v_into, K.Store v_from ->
Log.debugf 15
(fun k->k "(@[cc.merge.th-on-merge@ :th %s@])" K.name);
(* FIXME: explanation is a=ra, e_ab, b=rb *)
Th.on_merge cc r_into v_into r_from v_from e_ab
| _ -> assert false
end
| exception Not_found -> ())
in
(* restore old data, if it changed *)
if not @@ K_map.equal th th_into then (
on_backtrack cc (fun () -> r_into.n_th_data <- th_into);
);
r_into.n_th_data <- th;
end;
begin begin
(* parents might have a different signature, check for collisions *) (* parents might have a different signature, check for collisions *)
N.iter_parents r_from N.iter_parents r_from
@ -773,7 +787,6 @@ module Make(A: ARG) = struct
assert (u.n_root == r_from); assert (u.n_root == r_from);
u.n_root <- r_into); u.n_root <- r_into);
(* now merge the classes *) (* now merge the classes *)
let r_into_old_tags = r_into.n_tags in
let r_into_old_next = r_into.n_next in let r_into_old_next = r_into.n_next in
let r_from_old_next = r_from.n_next in let r_from_old_next = r_from.n_next in
let r_into_old_parents = r_into.n_parents in let r_into_old_parents = r_into.n_parents in
@ -786,11 +799,9 @@ module Make(A: ARG) = struct
N.pp r_from N.pp r_into); N.pp r_from N.pp r_into);
r_into.n_next <- r_into_old_next; r_into.n_next <- r_into_old_next;
r_from.n_next <- r_from_old_next; r_from.n_next <- r_from_old_next;
r_into.n_tags <- r_into_old_tags;
r_into.n_parents <- r_into_old_parents; r_into.n_parents <- r_into_old_parents;
N.iter_class_ r_from (fun u -> u.n_root <- r_from); N.iter_class_ r_from (fun u -> u.n_root <- r_from);
); );
r_into.n_tags <- new_tags;
(* swap [into.next] and [from.next], merging the classes *) (* swap [into.next] and [from.next], merging the classes *)
r_into.n_next <- r_from_old_next; r_into.n_next <- r_from_old_next;
r_from.n_next <- r_into_old_next; r_from.n_next <- r_into_old_next;
@ -810,10 +821,9 @@ module Make(A: ARG) = struct
| _ -> assert false); | _ -> assert false);
a.n_expl <- FL_some {next=b; expl=e_ab}; a.n_expl <- FL_some {next=b; expl=e_ab};
end; end;
(* notify listeners of the merge *)
List.iter (fun f -> f cc r_into r_from e_ab) cc.on_merge
) )
(* FIXME: remove
and task_distinct_ cc acts (l:node list) tag expl : unit = and task_distinct_ cc acts (l:node list) tag expl : unit =
let l = List.map (fun n -> n, find_ n) l in let l = List.map (fun n -> n, find_ n) l in
let coll = let coll =
@ -832,6 +842,7 @@ module Make(A: ARG) = struct
(* put a tag on all equivalence classes, that will make their merge fail *) (* put a tag on all equivalence classes, that will make their merge fail *)
List.iter (fun (_,n) -> add_tag_n cc n tag expl) l List.iter (fun (_,n) -> add_tag_n cc n tag expl) l
end end
*)
(* we are merging [r1] with [r2==Bool(sign)], so propagate each term [u1] (* we are merging [r1] with [r2==Bool(sign)], so propagate each term [u1]
in the equiv class of [r1] that is a known literal back to the SAT solver in the equiv class of [r1] that is a known literal back to the SAT solver
@ -864,6 +875,61 @@ module Make(A: ARG) = struct
acts.Msat.acts_propagate lit reason acts.Msat.acts_propagate lit reason
| _ -> ()) | _ -> ())
module Theory = struct
type cc = t
type t = theory
type 'a key = (term,lit,'a) Key.t
(* raise a conflict *)
let raise_conflict cc _n1 _n2 expl =
Log.debugf 5
(fun k->k "(@[cc.theory.raise-conflict@ :n1 %a@ :n2 %a@ :expl %a@])"
N.pp _n1 N.pp _n2 Expl.pp expl);
merge_classes cc (true_ cc) (false_ cc) expl
let merge cc n1 n2 expl =
Log.debugf 5
(fun k->k "(@[cc.theory.merge@ :n1 %a@ :n2 %a@ :expl %a@])" N.pp n1 N.pp n2 Expl.pp expl);
merge_classes cc n1 n2 expl
let add_term = add_term
let get_data _cc n key =
assert (N.is_root n);
K_map.find key n.n_th_data
(* FIXME: call micro theory here? in case of merge *)
(* update data for [n] *)
let add_data (type a) (self:cc) (n:node) (key: a key) (v:a) : unit =
let n = find_ n in
let map = n.n_th_data in
let old_v = K_map.find key map in
let v', is_diff = match old_v with
| None -> v, true
| Some old_v ->
let (module K) = key in
let v' = K.merge old_v v in
v', K.equal v v'
in
if is_diff then (
on_backtrack self (fun () -> n.n_th_data <- map);
);
n.n_th_data <- K_map.add key v' map;
()
let make (type a) ~(key:a key) ~on_merge ~on_new_term () : t =
let module Th = struct
type nonrec cc = cc
type data = a
let key = key
let key_id = Key.id key
let on_merge = on_merge
let on_new_term = on_new_term
end in
(module Th : THEORY with type cc=cc)
end
let check_invariants_ (cc:t) = let check_invariants_ (cc:t) =
Log.debug 5 "(cc.check-invariants)"; Log.debug 5 "(cc.check-invariants)";
Log.debugf 15 (fun k-> k "%a" pp_full cc); Log.debugf 15 (fun k-> k "%a" pp_full cc);
@ -943,11 +1009,12 @@ module Make(A: ARG) = struct
Sequence.iter (assert_lit cc) lits Sequence.iter (assert_lit cc) lits
let assert_eq cc t1 t2 (e:lit list) : unit = let assert_eq cc t1 t2 (e:lit list) : unit =
let expl = Expl.mk_lits e in let expl = Expl.mk_list @@ List.rev_map Expl.mk_lit e in
let n1 = add_term cc t1 in let n1 = add_term cc t1 in
let n2 = add_term cc t2 in let n2 = add_term cc t2 in
merge_classes cc n1 n2 expl merge_classes cc n1 n2 expl
(* FIXME: remove
(* generative tag used to annotate classes that can't be merged *) (* generative tag used to annotate classes that can't be merged *)
let distinct_tag_ = ref 0 let distinct_tag_ = ref 0
@ -958,14 +1025,23 @@ module Make(A: ARG) = struct
(fun k->k "(@[cc.assert_distinct@ (@[%a@])@ :tag %d@])" (Util.pp_list T.pp) l tag); (fun k->k "(@[cc.assert_distinct@ (@[%a@])@ :tag %d@])" (Util.pp_list T.pp) l tag);
let l = List.map (add_term cc) l in let l = List.map (add_term cc) l in
Vec.push cc.combine @@ CT_distinct (l, tag, Expl.mk_lit lit) Vec.push cc.combine @@ CT_distinct (l, tag, Expl.mk_lit lit)
*)
let create ?(on_merge=[]) ?(on_new_term=[]) ?(size=`Big) (tst:term_state) : t = let add_th (self:t) (th:theory) : unit =
let (module Th) = th in
if IM.mem Th.key_id self.theories then (
Error.errorf "attempt to add two theories with key %a" Key.pp Th.key
);
Log.debugf 3 (fun k->k "(@[@{<green>cc.add-theory@} %a@])" Key.pp Th.key);
self.theories <- IM.add Th.key_id th self.theories
let create ?th:(theories=[]) ?(size=`Big) (tst:term_state) : t =
let size = match size with `Small -> 128 | `Big -> 2048 in let size = match size with `Small -> 128 | `Big -> 2048 in
let rec cc = { let rec cc = {
tst; tst;
tbl = T_tbl.create size; tbl = T_tbl.create size;
signatures_tbl = Sig_tbl.create size; signatures_tbl = Sig_tbl.create size;
on_merge; on_new_term; theories=IM.empty;
pending=Vec.create(); pending=Vec.create();
combine=Vec.create(); combine=Vec.create();
ps_lits=[]; ps_lits=[];
@ -981,6 +1057,7 @@ module Make(A: ARG) = struct
in in
ignore (Lazy.force true_ : node); ignore (Lazy.force true_ : node);
ignore (Lazy.force false_ : node); ignore (Lazy.force false_ : node);
List.iter (add_th cc) theories; (* now add theories *)
cc cc
let[@inline] find_t cc t : repr = let[@inline] find_t cc t : repr =

View file

@ -2,11 +2,8 @@
module type ARG = Congruence_closure_intf.ARG module type ARG = Congruence_closure_intf.ARG
module type S = Congruence_closure_intf.S module type S = Congruence_closure_intf.S
module type THEORY_DATA = Congruence_closure_intf.THEORY_DATA
module type THEORY_KEY = Congruence_closure_intf.THEORY_KEY module type THEORY_KEY = Congruence_closure_intf.THEORY_KEY
type ('t, 'a) theory_data = ('t,'a) Congruence_closure_intf.theory_data
module Key : THEORY_KEY module Key : THEORY_KEY
module Make(A: ARG) module Make(A: ARG)

View file

@ -1,36 +1,35 @@
module type ARG = CC_types.FULL module type ARG = CC_types.FULL
(** Data stored by a theory for its own terms. module type THEORY_KEY = sig
type ('term,'lit,'a) t
(** An access key for theories which have per-class data ['a] *)
It needs to form a commutative monoid where values can be unmerged upon val create :
backtracking. ?pp:'a Fmt.printer ->
*) name:string ->
module type THEORY_DATA = sig eq:('a -> 'a -> bool) ->
type term merge:('a -> 'a -> 'a) ->
type t unit ->
('term,'lit,'a) t
(** Generative creation of keys for the given theory data.
val empty : t @param eq : Equality. This is used to optimize backtracking info.
val equal : t -> t -> bool @param merge :
(** Equality. This is used to optimize backtracking info. *) [merge d1 d2] is called when merging classes with data [d1] and [d2]
val merge : t -> t -> t
(** [merge d1 d2] is called when merging classes with data [d1] and [d2]
respectively. The theory should already have checked that the merge respectively. The theory should already have checked that the merge
is compatible, and this produces the combined data for terms in the is compatible, and this produces the combined data for terms in the
merged class. *) merged class.
end @param name name of the theory which owns this data
@param pp a printer for the data
*)
type ('t, 'a) theory_data = (module THEORY_DATA with type term = 't and type t = 'a) val equal : ('t,'lit,_) t -> ('t,'lit,_) t -> bool
(** Checks if two keys are equal (generatively) *)
module type THEORY_KEY = sig val pp : _ t Fmt.printer
type ('t, 'a) t (** Prints the name of the key. *)
(** An access key for theories that use terms ['t] and which have
per-class data ['a] *)
val create : ('t, 'a) theory_data -> ('t, 'a) t
(** Generative creation of keys for the given theory data. *)
end end
module type S = sig module type S = sig
@ -87,12 +86,9 @@ module type S = sig
type t type t
val pp : t Fmt.printer val pp : t Fmt.printer
val mk_reduction : t
val mk_congruence : N.t -> N.t -> t
val mk_merge : N.t -> N.t -> t val mk_merge : N.t -> N.t -> t
val mk_merges : (N.t * N.t) list -> t
val mk_lit : lit -> t val mk_lit : lit -> t
val mk_lits : lit list -> t val mk_list : t list -> t
end end
type node = N.t type node = N.t
@ -119,34 +115,52 @@ module type S = sig
(** Actions available to the theory *) (** Actions available to the theory *)
type sat_actions = (Msat.void, lit, Msat.void, proof) Msat.acts type sat_actions = (Msat.void, lit, Msat.void, proof) Msat.acts
module Theory : sig
type cc = t
type t
type 'a key = (term,lit,'a) Key.t
val raise_conflict : cc -> Expl.t -> unit
(** Raise a conflict with the given explanation
it must be a theory tautology that [expl ==> absurd].
To be used in theories. *)
val merge : cc -> N.t -> N.t -> Expl.t -> unit
(** Merge these two nodes given this explanation.
It must be a theory tautology that [expl ==> n1 = n2].
To be used in theories. *)
val add_term : cc -> term -> N.t
(** Add/retrieve node for this term.
To be used in theories *)
val get_data : cc -> N.t -> 'a key -> 'a option
(** Get data information for this particular representative *)
val add_data : cc -> N.t -> 'a key -> 'a -> unit
(** Add data to this particular representative. Will be backtracked. *)
val make :
key:'a key ->
on_merge:(cc -> N.t -> 'a -> N.t -> 'a -> Expl.t -> unit) ->
on_new_term:(cc -> term -> 'a option) ->
unit ->
t
(** Build a micro theory. It can use the callbacks above. *)
end
val create : val create :
?on_merge:(t -> repr -> repr -> explanation -> unit) list -> ?th:Theory.t list ->
?on_new_term:(t -> repr -> term -> unit) list ->
?size:[`Small | `Big] -> ?size:[`Small | `Big] ->
term_state -> term_state ->
t t
(** Create a new congruence closure. *) (** Create a new congruence closure. *)
val on_merge : t -> (t -> repr -> repr -> explanation -> unit) -> unit val add_th : t -> Theory.t -> unit
(** Add a callback, to be called whenever two classes are merged *) (** Add a (micro) theory to the congruence closure.
@raise Error.Error if there is already a theory with
val on_new_term : t -> (t -> repr -> term -> unit) -> unit the same key. *)
(** Add a callback, to be called whenever a node is added *)
val merge_classes : t -> node -> node -> explanation -> unit
(** Merge the two given nodes with given explanation.
It must be a theory tautology that [expl ==> n1 = n2] *)
val th_data_get : t -> N.t -> (term, 'a) Key.t -> 'a option
(** Get data information for this particular representative *)
val th_data_add : t -> N.t -> (term, 'a) Key.t -> 'a -> unit
(** Add the given data to this node (or rather, to its representative).
This will be backtracked. *)
(* TODO: merge true/false?
val raise_conflict : CC.t -> CC.N.t -> CC.N.t -> Expl.t -> 'a
*)
val set_as_lit : t -> N.t -> lit -> unit val set_as_lit : t -> N.t -> lit -> unit
(** map the given node to a literal. *) (** map the given node to a literal. *)
@ -173,11 +187,12 @@ module type S = sig
val assert_eq : t -> term -> term -> lit list -> unit val assert_eq : t -> term -> term -> lit list -> unit
(** merge the given terms with some explanations *) (** merge the given terms with some explanations *)
(* TODO: remove and move into its own library as a micro theory *) (* TODO: remove and move into its own library as a micro theory
val assert_distinct : t -> term list -> neq:term -> lit -> unit val assert_distinct : t -> term list -> neq:term -> lit -> unit
(** [assert_distinct l ~neq:u e] asserts all elements of [l] are distinct (** [assert_distinct l ~neq:u e] asserts all elements of [l] are distinct
because [lit] is true because [lit] is true
precond: [u = distinct l] *) precond: [u = distinct l] *)
*)
val check : t -> sat_actions -> unit val check : t -> sat_actions -> unit
(** Perform all pending operations done via {!assert_eq}, {!assert_lit}, etc. (** Perform all pending operations done via {!assert_eq}, {!assert_lit}, etc.

View file

@ -16,6 +16,7 @@ module type S = Congruence_closure.S
module Mini_cc = Mini_cc module Mini_cc = Mini_cc
module Congruence_closure = Congruence_closure module Congruence_closure = Congruence_closure
module Key = Congruence_closure.Key
module Make = Congruence_closure.Make module Make = Congruence_closure.Make

View file

@ -150,7 +150,7 @@ let eval (m:t) (t:Term.t) : Value.t option =
let b = aux b in let b = aux b in
if Value.equal a b then Value.true_ else Value.false_ if Value.equal a b then Value.true_ else Value.false_
| App_cst (c, args) -> | App_cst (c, args) ->
begin try Term.Map.find t m.values try Term.Map.find t m.values
with Not_found -> with Not_found ->
match Cst.view c with match Cst.view c with
| Cst_def udef -> | Cst_def udef ->
@ -168,7 +168,6 @@ let eval (m:t) (t:Term.t) : Value.t option =
| exception Not_found -> | exception Not_found ->
raise No_value (* no particular interpretation *) raise No_value (* no particular interpretation *)
end end
end
in in
try Some (aux t) try Some (aux t)
with No_value -> None with No_value -> None

View file

@ -13,9 +13,12 @@ module Lit = Lit
module Theory_combine = Theory_combine module Theory_combine = Theory_combine
module Theory = Theory module Theory = Theory
module Solver = Solver module Solver = Solver
module CC = CC
module Solver_types = Solver_types module Solver_types = Solver_types
type theory = Theory.t
(**/**) (**/**)
module Vec = Msat.Vec module Vec = Msat.Vec
module Log = Msat.Log module Log = Msat.Log

View file

@ -175,9 +175,10 @@ let assume (self:t) (c:Lit.t IArray.t) : unit =
let c = IArray.to_array_map (Sat_solver.make_atom sat) c in let c = IArray.to_array_map (Sat_solver.make_atom sat) c in
Sat_solver.add_clause_a sat c Proof_default Sat_solver.add_clause_a sat c Proof_default
(* TODO: remove? use a special constant + micro theory instead? *) (* TODO: remove? use a special constant + micro theory instead?
let[@inline] assume_distinct self l ~neq lit : unit = let[@inline] assume_distinct self l ~neq lit : unit =
CC.assert_distinct (cc self) l lit ~neq CC.assert_distinct (cc self) l lit ~neq
*)
let check_model (_s:t) : unit = let check_model (_s:t) : unit =
Log.debug 1 "(smt.solver.check-model)"; Log.debug 1 "(smt.solver.check-model)";

View file

@ -55,7 +55,9 @@ val mk_atom_t : t -> ?sign:bool -> Term.t -> Atom.t
val assume : t -> Lit.t IArray.t -> unit val assume : t -> Lit.t IArray.t -> unit
(* TODO: use the theory instead
val assume_distinct : t -> Term.t list -> neq:Term.t -> Lit.t -> unit val assume_distinct : t -> Term.t list -> neq:Term.t -> Lit.t -> unit
*)
val solve : val solve :
?on_exit:(unit -> unit) list -> ?on_exit:(unit -> unit) list ->

View file

@ -23,16 +23,11 @@ module CC_expl = CC.Expl
(** Actions available to a theory during its lifetime *) (** Actions available to a theory during its lifetime *)
module type ACTIONS = sig module type ACTIONS = sig
val cc : CC.t
val raise_conflict: conflict -> 'a val raise_conflict: conflict -> 'a
(** Give a conflict clause to the solver *) (** Give a conflict clause to the solver *)
val propagate_eq: Term.t -> Term.t -> Lit.t list -> unit
(** Propagate an equality [t = u] because [e].
TODO: use [CC.Expl] instead, with lit/merge constructors *)
val propagate_distinct: Term.t list -> neq:Term.t -> Lit.t -> unit
(** Propagate a [distinct l] because [e] (where [e = neq] *)
val propagate: Lit.t -> (unit -> Lit.t list) -> unit val propagate: Lit.t -> (unit -> Lit.t list) -> unit
(** Propagate a boolean using a unit clause. (** Propagate a boolean using a unit clause.
[expl => lit] must be a theory lemma, that is, a T-tautology *) [expl => lit] must be a theory lemma, that is, a T-tautology *)
@ -48,16 +43,6 @@ module type ACTIONS = sig
val add_persistent_axiom: Lit.t list -> unit val add_persistent_axiom: Lit.t list -> unit
(** Add toplevel clause to the SAT solver. This clause will (** Add toplevel clause to the SAT solver. This clause will
not be backtracked. *) not be backtracked. *)
val cc_add_term: Term.t -> CC_eq_class.t
(** add/get term to the congruence closure *)
val cc_find: CC_eq_class.t -> CC_eq_class.t
(** Find representative of this in the congruence closure *)
val cc_all_classes: CC_eq_class.t Sequence.t
(** All current equivalence classes
(caution: linear in the number of terms existing in the congruence closure) *)
end end
type actions = (module ACTIONS) type actions = (module ACTIONS)

View file

@ -51,9 +51,8 @@ let assert_lits_ ~final (self:t) acts (lits:Lit.t Sequence.t) : unit =
(* transmit to theories. *) (* transmit to theories. *)
CC.check cc acts; CC.check cc acts;
let module A = struct let module A = struct
let cc = cc
let[@inline] raise_conflict c : 'a = acts.Msat.acts_raise_conflict c Proof_default let[@inline] raise_conflict c : 'a = acts.Msat.acts_raise_conflict c Proof_default
let[@inline] propagate_eq t u expl : unit = CC.assert_eq cc t u expl
let propagate_distinct ts ~neq expl = CC.assert_distinct cc ts ~neq expl
let[@inline] propagate p cs : unit = let[@inline] propagate p cs : unit =
acts.Msat.acts_propagate p (Msat.Consequence (fun () -> cs(), Proof_default)) acts.Msat.acts_propagate p (Msat.Consequence (fun () -> cs(), Proof_default))
let[@inline] propagate_l p cs : unit = propagate p (fun()->cs) let[@inline] propagate_l p cs : unit = propagate p (fun()->cs)
@ -61,9 +60,6 @@ let assert_lits_ ~final (self:t) acts (lits:Lit.t Sequence.t) : unit =
acts.Msat.acts_add_clause ~keep:false lits Proof_default acts.Msat.acts_add_clause ~keep:false lits Proof_default
let[@inline] add_persistent_axiom lits : unit = let[@inline] add_persistent_axiom lits : unit =
acts.Msat.acts_add_clause ~keep:true lits Proof_default acts.Msat.acts_add_clause ~keep:true lits Proof_default
let[@inline] cc_add_term t = CC.add_term cc t
let[@inline] cc_find t = CC.find cc t
let cc_all_classes = CC.all_classes cc
end in end in
let acts = (module A : Theory.ACTIONS) in let acts = (module A : Theory.ACTIONS) in
theories self theories self

View file

@ -8,6 +8,7 @@ type 'a or_error = ('a, string) CCResult.t
module E = CCResult module E = CCResult
module A = Ast module A = Ast
module Form = Sidekick_th_bool.Bool_term module Form = Sidekick_th_bool.Bool_term
module Distinct = Sidekick_th_distinct
module Fmt = CCFormat module Fmt = CCFormat
module Dot = Msat_backend.Dot.Make(Solver.Sat_solver)(Msat_backend.Dot.Default(Solver.Sat_solver)) module Dot = Msat_backend.Dot.Make(Solver.Sat_solver)(Msat_backend.Dot.Default(Solver.Sat_solver))
@ -137,7 +138,7 @@ module Conv = struct
in in
Form.and_l tst (curry_eq l) Form.and_l tst (curry_eq l)
| A.Op (A.Distinct, l) -> | A.Op (A.Distinct, l) ->
Form.distinct_l tst @@ List.map (aux subst) l Distinct.distinct_l tst @@ List.map (aux subst) l
| A.Not f -> Form.not_ tst (aux subst f) | A.Not f -> Form.not_ tst (aux subst f)
| A.Bool true -> Term.true_ tst | A.Bool true -> Term.true_ tst
| A.Bool false -> Term.false_ tst | A.Bool false -> Term.false_ tst

View file

@ -3,12 +3,8 @@
(name sidekick_smtlib) (name sidekick_smtlib)
(public_name sidekick.smtlib) (public_name sidekick.smtlib)
(libraries containers zarith msat sidekick.smt sidekick.util (libraries containers zarith msat sidekick.smt sidekick.util
sidekick.smt.th-bool msat.backend) sidekick.smt.th-bool sidekick.smt.th-distinct msat.backend)
(flags :standard -w +a-4-42-44-48-50-58-32-60@8 (flags :standard -open Sidekick_util))
-safe-string -color always -open Sidekick_util)
(ocamlopt_flags :standard -O3 -color always -bin-annot
-unbox-closures -unbox-closures-factor 20)
)
(menhir (modules Parser)) (menhir (modules Parser))

View file

@ -3,11 +3,9 @@
type 'a view = type 'a view =
| B_not of 'a | B_not of 'a
| B_eq of 'a * 'a
| B_and of 'a IArray.t | B_and of 'a IArray.t
| B_or of 'a IArray.t | B_or of 'a IArray.t
| B_imply of 'a IArray.t * 'a | B_imply of 'a IArray.t * 'a
| B_distinct of 'a IArray.t
| B_atom of 'a | B_atom of 'a
(** {2 Interface for a representation of boolean terms} *) (** {2 Interface for a representation of boolean terms} *)

View file

@ -18,7 +18,6 @@ let id_not = ID.make "not"
let id_and = ID.make "and" let id_and = ID.make "and"
let id_or = ID.make "or" let id_or = ID.make "or"
let id_imply = ID.make "=>" let id_imply = ID.make "=>"
let id_distinct = ID.make "distinct"
let equal = T.equal let equal = T.equal
let hash = T.hash let hash = T.hash
@ -34,17 +33,14 @@ let view_id cst_id args =
(* conclusion is stored first *) (* conclusion is stored first *)
let len = IArray.length args in let len = IArray.length args in
B_imply (IArray.sub args 1 (len-1), IArray.get args 0) B_imply (IArray.sub args 1 (len-1), IArray.get args 0)
) else if ID.equal cst_id id_distinct then (
B_distinct args
) else ( ) else (
raise_notrace Not_a_th_term raise_notrace Not_a_th_term
) )
let view_as_bool (t:T.t) : T.t view = let view_as_bool (t:T.t) : T.t view =
match T.view t with match T.view t with
| Eq (a,b) -> B_eq (a,b)
| App_cst ({cst_id; _}, args) -> | App_cst ({cst_id; _}, args) ->
begin try view_id cst_id args with Not_a_th_term -> B_atom t end (try view_id cst_id args with Not_a_th_term -> B_atom t)
| _ -> B_atom t | _ -> B_atom t
module C = struct module C = struct
@ -69,14 +65,7 @@ module C = struct
| B_imply (_, V_bool true) -> Value.true_ | B_imply (_, V_bool true) -> Value.true_
| B_imply (a,_) when IArray.exists Value.is_false a -> Value.true_ | B_imply (a,_) when IArray.exists Value.is_false a -> Value.true_
| B_imply (a,b) when IArray.for_all Value.is_bool a && Value.is_bool b -> Value.false_ | B_imply (a,b) when IArray.for_all Value.is_bool a && Value.is_bool b -> Value.false_
| B_eq (a,b) -> Value.bool @@ Value.equal a b
| B_atom v -> v | B_atom v -> v
| B_distinct a ->
if
Sequence.diagonal (IArray.to_seq a)
|> Sequence.for_all (fun (x,y) -> not @@ Value.equal x y)
then Value.true_
else Value.false_
| B_not _ | B_and _ | B_or _ | B_imply _ | B_not _ | B_and _ | B_or _ | B_imply _
-> Error.errorf "non boolean value in boolean connective" -> Error.errorf "non boolean value in boolean connective"
@ -92,7 +81,6 @@ module C = struct
let and_ = mk_cst id_and let and_ = mk_cst id_and
let or_ = mk_cst id_or let or_ = mk_cst id_or
let imply = mk_cst id_imply let imply = mk_cst id_imply
let distinct = mk_cst id_distinct
end end
let as_id id (t:T.t) : T.t IArray.t option = let as_id id (t:T.t) : T.t IArray.t option =
@ -152,20 +140,9 @@ let imply_l st xs y = match xs with
let imply st a b = imply_a st (IArray.singleton a) b let imply st a b = imply_a st (IArray.singleton a) b
let distinct st a =
if IArray.length a <= 1
then T.true_ st
else T.app_cst st C.distinct a
let distinct_l st = function
| [] | [_] -> T.true_ st
| xs -> distinct st (IArray.of_list xs)
let make st = function let make st = function
| B_atom t -> t | B_atom t -> t
| B_eq (a,b) -> T.eq st a b
| B_and l -> and_a st l | B_and l -> and_a st l
| B_or l -> or_a st l | B_or l -> or_a st l
| B_imply (a,b) -> imply_a st a b | B_imply (a,b) -> imply_a st a b
| B_not t -> not_ st t | B_not t -> not_ st t
| B_distinct l -> distinct st l

View file

@ -15,8 +15,6 @@ val imply_a : state -> term IArray.t -> term -> term
val imply_l : state -> term list -> term -> term val imply_l : state -> term list -> term -> term
val eq : state -> term -> term -> term val eq : state -> term -> term -> term
val neq : state -> term -> term -> term val neq : state -> term -> term -> term
val distinct : state -> term IArray.t -> term
val distinct_l : state -> term list -> term
val and_a : state -> term IArray.t -> term val and_a : state -> term IArray.t -> term
val and_l : state -> term list -> term val and_l : state -> term list -> term
val or_a : state -> term IArray.t -> term val or_a : state -> term IArray.t -> term

View file

@ -9,11 +9,9 @@ module Th_dyn_tseitin = Th_dyn_tseitin
type 'a view = 'a Intf.view = type 'a view = 'a Intf.view =
| B_not of 'a | B_not of 'a
| B_eq of 'a * 'a
| B_and of 'a IArray.t | B_and of 'a IArray.t
| B_or of 'a IArray.t | B_or of 'a IArray.t
| B_imply of 'a IArray.t * 'a | B_imply of 'a IArray.t * 'a
| B_distinct of 'a IArray.t
| B_atom of 'a | B_atom of 'a
module type BOOL_TERM = Intf.BOOL_TERM module type BOOL_TERM = Intf.BOOL_TERM

View file

@ -15,14 +15,7 @@ module Make(Term : ARG) = struct
type term = Term.t type term = Term.t
module T_tbl = CCHashtbl.Make(Term) module T_tbl = CCHashtbl.Make(Term)
module Lit = Sidekick_smt.Lit
module Lit = struct
include Sidekick_smt.Lit
let eq tst a b = atom tst ~sign:true @@ Term.make tst (B_eq (a,b))
let neq tst a b = neg @@ eq tst a b
end
let pp_c out c = Fmt.fprintf out "(@[<hv>%a@])" (Util.pp_list Lit.pp) c
type t = { type t = {
tst: Term.state; tst: Term.state;
@ -39,22 +32,7 @@ module Make(Term : ARG) = struct
in in
match v with match v with
| B_not _ -> assert false (* normalized *) | B_not _ -> assert false (* normalized *)
| B_atom _ | B_eq _ -> () (* CC will manage *) | B_atom _ -> () (* CC will manage *)
| B_distinct l ->
let l = IArray.to_list l in
if Lit.sign lit then (
A.propagate_distinct l ~neq:lit_t lit
) else if final && not @@ expanded () then (
(* add clause [distinct t1…tn _{i,j>i} t_i=j] *)
let c =
Sequence.diagonal_l l
|> Sequence.map (fun (t,u) -> Lit.eq self.tst t u)
|> Sequence.to_rev_list
in
let c = Lit.neg lit :: c in
Log.debugf 5 (fun k->k "(@[tseitin.distinct.case-split@ %a@])" pp_c c);
add_axiom c
)
| B_and subs -> | B_and subs ->
if Lit.sign lit then ( if Lit.sign lit then (
(* propagate [lit => subs_i] *) (* propagate [lit => subs_i] *)
@ -105,7 +83,7 @@ module Make(Term : ARG) = struct
(fun lit -> (fun lit ->
let t = Lit.term lit in let t = Lit.term lit in
match Term.view_as_bool t with match Term.view_as_bool t with
| B_atom _ | B_eq _ -> () | B_atom _ -> ()
| v -> tseitin ~final self acts lit t v) | v -> tseitin ~final self acts lit t v)
let partial_check (self:t) acts (lits:Lit.t Sequence.t) = let partial_check (self:t) acts (lits:Lit.t Sequence.t) =

View file

@ -12,11 +12,5 @@ module type ARG = Bool_intf.BOOL_TERM
module Make(Term : ARG) : sig module Make(Term : ARG) : sig
type term = Term.t type term = Term.t
module Lit : sig
type t = Sidekick_smt.Lit.t
val eq : Term.state -> term -> term -> t
val neq : Term.state -> term -> term -> t
end
val th : Sidekick_smt.Theory.t val th : Sidekick_smt.Theory.t
end end

View file

@ -2,8 +2,5 @@
(name Sidekick_th_bool) (name Sidekick_th_bool)
(public_name sidekick.smt.th-bool) (public_name sidekick.smt.th-bool)
(libraries containers sidekick.smt) (libraries containers sidekick.smt)
(flags :standard -w +a-4-44-48-58-60@8 (flags :standard -open Sidekick_util))
-color always -safe-string -short-paths -open Sidekick_util)
(ocamlopt_flags :standard -O3 -color always
-unbox-closures -unbox-closures-factor 20))