refactor types for terms and congruence closure

- terms are extensible
- explanations have a custom case, shaped as a term
- remove distinction repr/node in Equiv_class, for simplicity
- make propositional connectives n-ary
This commit is contained in:
Simon Cruanes 2018-01-30 21:55:37 -06:00
parent 2aab43f95d
commit 50fe488dcb
11 changed files with 258 additions and 207 deletions

View file

@ -87,6 +87,7 @@ module type S = sig
(** Add the list of clauses to the current set of assumptions.
Modifies the sat solver state in place. *)
(* TODO: provide a local, backtrackable version *)
val add_clause : t -> clause -> unit
(** Lower level addition of clauses *)

View file

@ -3,7 +3,7 @@ open CDCL
open Solver_types
type node = Equiv_class.t
type repr = Equiv_class.repr
type repr = Equiv_class.t
(** A signature is a shallow term shape where immediate subterms
are representative *)
@ -14,12 +14,12 @@ end
module Sig_tbl = CCHashtbl.Make(Signature)
type merge_op = node * node * cc_explanation
type merge_op = node * node * explanation
(* a merge operation to perform *)
type actions =
| Propagate of Lit.t * cc_explanation list
| Split of Lit.t list * cc_explanation list
| Propagate of Lit.t * explanation list
| Split of Lit.t list * explanation list
| Merge of node * node (* merge these two classes *)
type t = {
@ -38,7 +38,7 @@ type t = {
(* register a function to be called when we backtrack *)
at_lvl_0: unit -> bool;
(* currently at level 0? *)
on_merge: (repr -> repr -> cc_explanation -> unit) list;
on_merge: (repr -> repr -> explanation -> unit) list;
(* callbacks to call when we merge classes *)
pending: node Vec.t;
(* nodes to check, maybe their new signature is in {!signatures_tbl} *)
@ -59,11 +59,6 @@ type t = {
several times.
See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *)
module CC_expl_set = CCSet.Make(struct
type t = cc_explanation
let compare = Solver_types.cmp_cc_expl
end)
let[@inline] is_root_ (n:node) : bool = n.n_root == n
let[@inline] size_ (r:repr) =
@ -78,7 +73,7 @@ let[@inline] mem (cc:t) (t:term): bool =
(* find representative, recursively, and perform path compression *)
let rec find_rec cc (n:node) : repr =
if n==n.n_root then (
Equiv_class.unsafe_repr_of_node n
n
) else (
let old_root = n.n_root in
let root = find_rec cc old_root in
@ -104,19 +99,20 @@ let[@inline] get_ cc (t:term) : node =
(* non-recursive, inlinable function for [find] *)
let[@inline] find st (n:node) : repr =
if n == n.n_root
then (Equiv_class.unsafe_repr_of_node n)
else find_rec st n
if n == n.n_root then n else find_rec st n
let[@inline] find_tn cc (t:term) : repr = get_ cc t |> find cc
let[@inline] find_tt cc (t:term) : term = find_tn cc t |> Equiv_class.Repr.term
let[@inline] find_tt cc (t:term) : term = find_tn cc t |> Equiv_class.term
let[@inline] same_class cc (n1:node)(n2:node): bool =
Equiv_class.Repr.equal (find cc n1) (find cc n2)
Equiv_class.equal (find cc n1) (find cc n2)
let[@inline] same_class_t cc (t1:term)(t2:term): bool =
Equiv_class.equal (find_tn cc t1) (find_tn cc t2)
(* compute signature *)
let signature cc (t:term): node term_cell option =
let find = (find_tn cc :> term -> node) in
let find = find_tn cc in
begin match Term.cell t with
| True | Builtin _
-> None
@ -124,6 +120,8 @@ let signature cc (t:term): node term_cell option =
| App_cst (f, a) -> App_cst (f, IArray.map find a) |> CCOpt.return
| If (a,b,c) -> If (find a, get_ cc b, get_ cc c) |> CCOpt.return
| Case (t, m) -> Case (find t, ID.Map.map (get_ cc) m) |> CCOpt.return
| Custom {view;tc} ->
Custom {tc; view=tc.tc_t_subst find view} |> CCOpt.return
end
(* find whether the given (parent) term corresponds to some signature
@ -151,7 +149,7 @@ let add_signature cc (t:term) (r:repr): unit = match signature cc t with
);
Sig_tbl.add cc.signatures_tbl s r;
| Some r' ->
assert (Equiv_class.Repr.equal r r');
assert (Equiv_class.equal r r');
end
let is_done (cc:t): bool =
@ -165,24 +163,24 @@ let push_pending cc t : unit =
let push_combine cc t u e : unit =
Log.debugf 5
(fun k->k "(@[<hv1>push_combine@ %a@ %a@ expl: %a@])"
Equiv_class.pp t Equiv_class.pp u pp_cc_explanation e);
Equiv_class.pp t Equiv_class.pp u Explanation.pp e);
Vec.push cc.combine (t,u,e)
let push_split cc (lits:lit list) (expl:cc_explanation list): unit =
let push_split cc (lits:lit list) (expl:explanation list): unit =
Log.debugf 5
(fun k->k "(@[<hv1>push_split@ (@[%a@])@ expl: (@[<hv>%a@])@])"
(Util.pp_list Lit.pp) lits (Util.pp_list pp_cc_explanation) expl);
(Util.pp_list Lit.pp) lits (Util.pp_list Explanation.pp) expl);
let l = Split (lits, expl) in
cc.actions <- l :: cc.actions
let push_propagation cc (lit:lit) (expl:cc_explanation list): unit =
let push_propagation cc (lit:lit) (expl:explanation list): unit =
Log.debugf 5
(fun k->k "(@[<hv1>push_propagate@ %a@ expl: (@[<hv>%a@])@])"
Lit.pp lit (Util.pp_list pp_cc_explanation) expl);
Lit.pp lit (Util.pp_list Explanation.pp) expl);
let l = Propagate (lit,expl) in
cc.actions <- l :: cc.actions
let[@inline] union cc (a:node) (b:node) (e:cc_explanation): unit =
let[@inline] union cc (a:node) (b:node) (e:explanation): unit =
if not (same_class cc a b) then (
push_combine cc a b e; (* start by merging [a=b] *)
)
@ -196,11 +194,11 @@ let rec reroot_expl cc (n:node): unit =
cc.on_backtrack (fun () -> n.n_expl <- old_expl);
);
begin match old_expl with
| None -> () (* already root *)
| Some (u, e_n_u) ->
| E_none -> () (* already root *)
| E_some {next=u; expl=e_n_u} ->
reroot_expl cc u;
u.n_expl <- Some (n, e_n_u);
n.n_expl <- None;
u.n_expl <- E_some {next=n; expl=e_n_u};
n.n_expl <- E_none;
end
(* TODO:
@ -208,19 +206,18 @@ let rec reroot_expl cc (n:node): unit =
- also, obtain merges of CC via callbacks / [pop_merges] afterwards?
*)
exception Exn_unsat of cc_explanation list
exception Exn_unsat of explanation Bag.t
let unsat (e:cc_explanation list): _ = raise (Exn_unsat e)
let unsat (e:explanation Bag.t): _ = raise (Exn_unsat e)
type result =
| Sat of actions list
| Unsat of cc_explanation list
| Unsat of explanation Bag.t
(* list of direct explanations to the conflict. *)
let[@inline] all_classes cc : repr Sequence.t =
Term.Tbl.values cc.tbl
|> Sequence.filter is_root_
|> Equiv_class.unsafe_repr_seq_of_seq
(* main CC algo: add terms from [pending] to the signature table,
check for collisions *)
@ -236,7 +233,7 @@ let rec update_pending (cc:t): result =
add_signature cc n.n_term (find cc n)
| Some u ->
(* must combine [t] with [r] *)
push_combine cc n (u:>node) (CC_congruence (n,(u:>node)))
push_combine cc n u(E_congruence (n,u))
end;
(* FIXME: when to actually evaluate?
eval_pending cc;
@ -257,8 +254,8 @@ and update_combine cc =
let a, b, e_ab = Vec.pop_last cc.combine in
let ra = find cc a in
let rb = find cc b in
if not (Equiv_class.Repr.equal ra rb) then (
assert (is_root_ (ra:>node));
if not (Equiv_class.equal ra rb) then (
assert (is_root_ ra);
assert (is_root_ (rb:>node));
(* We will merge [r_from] into [r_into].
we try to ensure that [size ra <= size rb] in general, unless
@ -296,11 +293,11 @@ and update_combine cc =
(* update explanations (a -> b), arbitrarily *)
begin
reroot_expl cc a;
assert (a.n_expl = None);
assert (a.n_expl = E_none);
if not (cc.at_lvl_0 ()) then (
cc.on_backtrack (fun () -> a.n_expl <- None);
cc.on_backtrack (fun () -> a.n_expl <- E_none);
);
a.n_expl <- Some (b, e_ab);
a.n_expl <- E_some {next=b; expl=e_ab};
end;
(* notify listeners of the merge *)
notify_merge cc r_from ~into:r_into e_ab;
@ -312,7 +309,7 @@ and update_combine cc =
(* Checks if [ra] and [~into] have compatible normal forms and can
be merged w.r.t. the theories.
Side effect: also pushes sub-tasks *)
and notify_merge cc (ra:repr) ~into:(rb:repr) (e:cc_explanation): unit =
and notify_merge cc (ra:repr) ~into:(rb:repr) (e:explanation): unit =
assert (is_root_ (ra:>node));
assert (is_root_ (rb:>node));
List.iter
@ -366,6 +363,7 @@ and add_new_term cc (t:term) : node =
add_sub_t c
| Case (u, _) -> add_sub_t u
| Builtin b -> Term.builtin_to_seq b add_sub_t
| Custom {view;tc} -> tc.tc_t_sub view add_sub_t
end;
(* remove term when we backtrack *)
if not (cc.at_lvl_0 ()) then (
@ -399,7 +397,7 @@ let assert_lit cc lit : unit = match Lit.view lit with
(* equate t and true/false *)
let rhs = if sign then true_ cc else false_ cc in
let n = add cc t in
push_combine cc n rhs (CC_lit lit);
push_combine cc n rhs (E_lit lit);
()
let create ?(size=2048) ~on_backtrack ~at_lvl_0 ~on_merge (tst:Term.state) : t =
@ -413,7 +411,7 @@ let create ?(size=2048) ~on_backtrack ~at_lvl_0 ~on_merge (tst:Term.state) : t =
on_backtrack;
at_lvl_0;
pending=Vec.make_empty Equiv_class.dummy;
combine= Vec.make_empty (nd,nd,CC_reduce_eq(nd,nd));
combine= Vec.make_empty (nd,nd,E_reduce_eq(nd,nd));
actions=[];
ps_lits=Lit.Set.empty;
ps_queue=Vec.make_empty (nd,nd);
@ -426,8 +424,8 @@ let create ?(size=2048) ~on_backtrack ~at_lvl_0 ~on_merge (tst:Term.state) : t =
(* distance from [t] to its root in the proof forest *)
let[@inline][@unroll 2] rec distance_to_root (n:node): int = match n.n_expl with
| None -> 0
| Some (t', _) -> 1 + distance_to_root t'
| E_none -> 0
| E_some {next=t'; _} -> 1 + distance_to_root t'
(* find the closest common ancestor of [a] and [b] in the proof forest *)
let find_common_ancestor (a:node) (b:node) : node =
@ -437,8 +435,8 @@ let find_common_ancestor (a:node) (b:node) : node =
let rec drop_ n t =
if n=0 then t
else match t.n_expl with
| None -> assert false
| Some (t', _) -> drop_ (n-1) t'
| E_none -> assert false
| E_some {next=t'; _} -> drop_ (n-1) t'
in
(* reduce to the problem where [a] and [b] have the same distance to root *)
let a, b =
@ -450,18 +448,13 @@ let find_common_ancestor (a:node) (b:node) : node =
let rec aux_same_dist a b =
if a==b then a
else match a.n_expl, b.n_expl with
| None, _ | _, None -> assert false
| Some (a', _), Some (b', _) -> aux_same_dist a' b'
| E_none, _ | _, E_none -> assert false
| E_some {next=a'; _}, E_some {next=b'; _} -> aux_same_dist a' b'
in
aux_same_dist a b
let[@inline] ps_add_obligation (cc:t) a b = Vec.push cc.ps_queue (a,b)
let[@inline] ps_add_lit ps l = ps.ps_lits <- Lit.Set.add l ps.ps_lits
let[@inline] ps_add_expl ps e = match e with
| CC_lit lit -> ps_add_lit ps lit
| CC_reduce_eq _ | CC_congruence _
| CC_injectivity _ | CC_reduction
-> ()
and ps_add_obligation_t cc (t1:term) (t2:term) =
let n1 = get_ cc t1 in
@ -473,41 +466,38 @@ let ps_clear (cc:t) =
Vec.clear cc.ps_queue;
()
let decompose_explain cc (e:cc_explanation): unit =
Log.debugf 5 (fun k->k "(@[decompose_expl@ %a@])" pp_cc_explanation e);
ps_add_expl cc e;
let rec decompose_explain cc (e:explanation): unit =
Log.debugf 5 (fun k->k "(@[decompose_expl@ %a@])" Explanation.pp e);
begin match e with
| CC_reduction
| CC_lit _ -> ()
| CC_reduce_eq (a, b) ->
| E_reduction -> ()
| E_lit lit -> ps_add_lit cc lit
| E_custom {args;_} ->
(* decompose sub-expls *)
List.iter (decompose_explain cc) args
| E_reduce_eq (a, b) ->
ps_add_obligation cc a b;
| CC_injectivity (t1,t2)
(* FIXME: should this be different from CC_congruence? just explain why t1==t2? *)
| CC_congruence (t1,t2) ->
| E_injectivity (t1,t2) ->
(* arguments of [t1], [t2] are equal by injectivity, so we
just need to explain why [t1=t2] *)
ps_add_obligation cc t1 t2
| E_congruence (t1,t2) ->
(* [t1] and [t2] must be applications of the same symbol to
arguments that are pairwise equal *)
begin match t1.n_term.term_cell, t2.n_term.term_cell with
| True, _ -> assert false (* no congruence here *)
| App_cst (f1, a1), App_cst (f2, a2) ->
assert (Cst.equal f1 f2);
assert (IArray.length a1 = IArray.length a2);
IArray.iter2 (ps_add_obligation_t cc) a1 a2
| Case (_t1, _m1), Case (_t2, _m2) ->
assert false
(* TODO: this should never happen
ps_add_obligation ps t1 t2;
ID.Map.iter
(fun id rhs1 ->
let rhs2 = ID.Map.find id m2 in
ps_add_obligation ps rhs1 rhs2)
m1;
*)
| If (a1,b1,c1), If (a2,b2,c2) ->
ps_add_obligation_t cc a1 a2;
ps_add_obligation_t cc b1 b2;
ps_add_obligation_t cc c1 c2;
| Builtin _, _ -> assert false
| Custom r1, Custom r2 ->
(* ask the theory to explain why [r1 = r2] *)
let l = r1.tc.tc_t_explain (same_class_t cc) r1.view r2.view in
List.iter (fun (t,u) -> ps_add_obligation_t cc t u) l
| If _, _
| Builtin _, _
| App_cst _, _
| Case _, _
| If _, _
| True, _
| Custom _, _
-> assert false
end
end
@ -517,8 +507,8 @@ let decompose_explain cc (e:cc_explanation): unit =
let rec explain_along_path ps (a:node) (parent_a:node) : unit =
if a!=parent_a then (
match a.n_expl with
| None -> assert false
| Some (next_a, e_a_b) ->
| E_none -> assert false
| E_some {next=next_a; expl=e_a_b} ->
decompose_explain ps e_a_b;
(* now prove [next_a = parent_a] *)
explain_along_path ps next_a parent_a
@ -530,17 +520,17 @@ let explain_loop (cc : t) : Lit.Set.t =
let a, b = Vec.pop_last cc.ps_queue in
Log.debugf 5
(fun k->k "(@[explain_loop at@ %a@ %a@])" Equiv_class.pp a Equiv_class.pp b);
assert (Equiv_class.Repr.equal (find cc a) (find cc b));
assert (Equiv_class.equal (find cc a) (find cc b));
let c = find_common_ancestor a b in
explain_along_path cc a c;
explain_along_path cc b c;
done;
cc.ps_lits
let explain_unfold cc (l:cc_explanation list): Lit.Set.t =
let explain_unfold cc (l:explanation list): Lit.Set.t =
Log.debugf 5
(fun k->k "(@[explain_confict@ (@[<hv>%a@])@])"
(Util.pp_list pp_cc_explanation) l);
(Util.pp_list Explanation.pp) l);
ps_clear cc;
List.iter (decompose_explain cc) l;
explain_loop cc

View file

@ -9,14 +9,14 @@ type t
type node = Equiv_class.t
(** Node in the congruence closure *)
type repr = Equiv_class.repr
type repr = Equiv_class.t
(** Node that is currently a representative *)
val create :
?size:int ->
on_backtrack:((unit -> unit) -> unit) ->
at_lvl_0:(unit -> bool) ->
on_merge:(repr -> repr -> cc_explanation -> unit) list ->
on_merge:(repr -> repr -> explanation -> unit) list ->
Term.state ->
t
(** Create a new congruence closure.
@ -30,7 +30,7 @@ val find : t -> node -> repr
val same_class : t -> node -> node -> bool
(** Are these two classes the same in the current CC? *)
val union : t -> node -> node -> cc_explanation -> unit
val union : t -> node -> node -> explanation -> unit
(** Merge the two equivalence classes. Will be undone on backtracking. *)
val assert_lit : t -> Lit.t -> unit
@ -48,19 +48,19 @@ val add_seq : t -> term Sequence.t -> unit
(** Add a sequence of terms to the congruence closure *)
type actions =
| Propagate of Lit.t * cc_explanation list
| Split of Lit.t list * cc_explanation list
| Propagate of Lit.t * explanation list
| Split of Lit.t list * explanation list
| Merge of node * node (* merge these two classes *)
type result =
| Sat of actions list
| Unsat of cc_explanation list
| Unsat of explanation Bag.t
(* list of direct explanations to the conflict. *)
val check : t -> result
val final_check : t -> result
val explain_unfold: t -> cc_explanation list -> Lit.Set.t
val explain_unfold: t -> explanation list -> Lit.Set.t
(** Unfold those explanations into a complete set of
literals implying them *)

View file

@ -1,9 +1,7 @@
open CDCL
open Solver_types
type t = cc_node
type repr = t
type payload = cc_node_payload
let field_expanded = Node_bits.mk_field ()
@ -11,6 +9,7 @@ let field_has_expansion_lit = Node_bits.mk_field ()
let field_is_lit = Node_bits.mk_field ()
let field_is_split = Node_bits.mk_field ()
let field_add_level_0 = Node_bits.mk_field()
let field_is_active = Node_bits.mk_field()
let () = Node_bits.freeze()
let[@inline] equal (n1:t) n2 = n1==n2
@ -19,19 +18,6 @@ let[@inline] term n = n.n_term
let[@inline] payload n = n.n_payload
let[@inline] pp out n = Term.pp out n.n_term
module Repr = struct
type node = t
type t = repr
let equal = equal
let hash = hash
let term = term
let payload = payload
let pp = pp
let[@inline] parents r = r.n_parents
let[@inline] class_ r = r.n_class
end
let make (t:term) : t =
let rec n = {
n_term=t;
@ -39,7 +25,7 @@ let make (t:term) : t =
n_class=Bag.empty;
n_parents=Bag.empty;
n_root=n;
n_expl=None;
n_expl=E_none;
n_payload=[];
} in
(* set [class(t) = {t}] *)
@ -82,5 +68,3 @@ let payload_pred ~f:p n =
let dummy = make Term.dummy
let[@inline] unsafe_repr_of_node n = n
let[@inline] unsafe_repr_seq_of_seq s = s

View file

@ -21,7 +21,6 @@ open Solver_types
*)
type t = cc_node
type repr = private t
type payload = cc_node_payload
val field_expanded : Node_bits.field
@ -42,6 +41,10 @@ val field_add_level_0 : Node_bits.field
(** Is the corresponding term to be re-added upon backtracking,
down to level 0? *)
val field_is_active : Node_bits.field
(** The term is needed for evaluation. We must try to evaluate it
or to find a value for it using the theory *)
(** {2 basics} *)
val term : t -> term
@ -50,20 +53,6 @@ val hash : t -> int
val pp : t Fmt.printer
val payload : t -> payload list
module Repr : sig
type node = t
type t = repr
val term : t -> term
val equal : t -> t -> bool
val hash : t -> int
val pp : t Fmt.printer
val payload : t -> payload list
val parents : t -> node Bag.t
val class_ : t -> node Bag.t
end
(** {2 Helpers} *)
val make : term -> t
@ -80,6 +69,4 @@ val set_payload : ?can_erase:(payload -> bool) -> t -> payload -> unit
(**/**)
val dummy : t
val unsafe_repr_of_node : t -> repr
val unsafe_repr_seq_of_seq : t Sequence.t -> repr Sequence.t
(**/**)

16
src/smt/Explanation.ml Normal file
View file

@ -0,0 +1,16 @@
open CDCL
open Solver_types
type t = explanation
let compare = cmp_exp
let equal a b = cmp_exp a b = 0
let pp = pp_explanation
module Set = CCSet.Make(struct
type t = explanation
let compare = compare
end)

View file

@ -23,13 +23,66 @@ and 'a term_cell =
| If of 'a * 'a * 'a
| Case of 'a * 'a ID.Map.t (* check head constructor *)
| Builtin of 'a builtin
| Custom of {
view: 'a term_view_custom;
tc: term_view_tc;
}
and 'a builtin =
| B_not of 'a
| B_eq of 'a * 'a
| B_and of 'a * 'a
| B_or of 'a * 'a
| B_imply of 'a * 'a
| B_and of 'a list
| B_or of 'a list
| B_imply of 'a list * 'a
(** Methods on the custom term view whose leaves are ['a].
Terms must be comparable, hashable, printable, and provide
some additional theory handles.
- [tc_t_sub] must return all immediate subterms (all ['a] contained in the term)
- [tc_t_subst] must use the function to replace all subterms (all the ['a]
returned by [tc_t_sub]) by ['b]
- [tc_t_relevant] must return a subset of [tc_t_sub] (possibly the same set).
The terms it returns will be activated and evaluated whenever possible.
Terms in [tc_t_sub t \ tc_t_relevant t] are considered for
congruence but not for evaluation.
- If [t1] and [t2] satisfy [tc_t_is_semantic] and have the same type,
then [tc_t_solve t1 t2] must succeed by returning some {!solve_result}.
- if [tc_t_equal eq a b = true], then [tc_t_explain eq a b] must
return all the pairs of equal subterms that are sufficient
for [a] and [b] to be equal.
*)
and term_view_tc = {
tc_t_pp : 'a. 'a Fmt.printer -> 'a term_view_custom Fmt.printer;
tc_t_equal : 'a. 'a CCEqual.t -> 'a term_view_custom CCEqual.t;
tc_t_hash : 'a. 'a Hash.t -> 'a term_view_custom Hash.t;
tc_t_ty : 'a. ('a -> ty) -> 'a term_view_custom -> ty;
tc_t_is_semantic : cc_node term_view_custom -> bool; (* is this a semantic term? semantic terms must be solvable *)
tc_t_solve: cc_node term_view_custom -> cc_node term_view_custom -> solve_result; (* solve an equation between classes *)
tc_t_sub : 'a. 'a term_view_custom -> 'a Sequence.t; (* iter on immediate subterms *)
tc_t_relevant : 'a. 'a term_view_custom -> 'a Sequence.t; (* iter on relevant immediate subterms *)
tc_t_subst : 'a 'b. ('a -> 'b) -> 'a term_view_custom -> 'b term_view_custom; (* substitute immediate subterms and canonize *)
tc_t_explain : 'a. 'a CCEqual.t -> 'a term_view_custom -> 'a term_view_custom -> ('a * 'a) list;
(* explain why the two views are equal *)
}
(** Custom term view for theories *)
and 'a term_view_custom = ..
(** The result of a call to {!solve}. *)
and solve_result =
| Solve_ok of {
subst: (cc_node * term) list; (** binding leaves to other terms *)
} (** Success, the two terms being equal is equivalent
to the given substitution *)
| Solve_fail of {
expl: explanation;
} (** Failure, because of the given explanation.
The two terms cannot be equal *)
(** A node of the congruence closure.
An equivalence class is represented by its "root" element,
@ -43,21 +96,32 @@ and cc_node = {
mutable n_class: cc_node Bag.t; (* terms in the same equiv class *)
mutable n_parents: cc_node Bag.t; (* parent terms of the whole equiv class *)
mutable n_root: cc_node; (* representative of congruence class (itself if a representative) *)
mutable n_expl: (cc_node * cc_explanation) option; (* the rooted forest for explanations *)
mutable n_expl: explanation_forest_link; (* the rooted forest for explanations *)
mutable n_payload: cc_node_payload list; (* list of theory payloads *)
}
(** Theory-extensible payloads *)
and cc_node_payload = ..
and explanation_forest_link =
| E_none
| E_some of {
next: cc_node;
expl: explanation;
}
(* atomic explanation in the congruence closure *)
and cc_explanation =
| CC_reduction (* by pure reduction, tautologically equal *)
| CC_lit of lit (* because of this literal *)
| CC_congruence of cc_node * cc_node (* same shape *)
| CC_injectivity of cc_node * cc_node (* arguments of those constructors *)
| CC_reduce_eq of cc_node * cc_node (* reduce because those are equal *)
(* TODO: theory expl *)
and explanation =
| E_reduction (* by pure reduction, tautologically equal *)
| E_lit of lit (* because of this literal *)
| E_congruence of cc_node * cc_node (* these terms are congruent *)
| E_injectivity of cc_node * cc_node (* injective function *)
| E_reduce_eq of cc_node * cc_node (* reduce because those are equal by reduction *)
| E_custom of {
name: ID.t; (* name of the rule *)
args: explanation list; (* sub-explanations *)
pp: (ID.t * explanation list) Fmt.printer;
} (** Custom explanation, typically for theories *)
(* boolean literal *)
and lit = {
@ -85,7 +149,7 @@ and cst_kind =
(* what kind of constant is that? *)
and cst_defined_info =
| Cst_recursive
| Cst_recursive (* TODO: the set of Horn rules compiled from the def *)
| Cst_non_recursive
(* this is a disjunction of sufficient conditions for the existence of
@ -171,23 +235,26 @@ let hash_lit a =
let cmp_cc_node a b = term_cmp_ a.n_term b.n_term
let cmp_cc_expl a b =
let rec cmp_exp a b =
let toint = function
| CC_congruence _ -> 0 | CC_lit _ -> 1
| CC_reduction -> 2 | CC_injectivity _ -> 3
| CC_reduce_eq _ -> 5
| E_congruence _ -> 0 | E_lit _ -> 1
| E_reduction -> 2 | E_injectivity _ -> 3
| E_reduce_eq _ -> 5
| E_custom _ -> 6
in
begin match a, b with
| CC_congruence (t1,t2), CC_congruence (u1,u2) ->
| E_congruence (t1,t2), E_congruence (u1,u2) ->
CCOrd.(cmp_cc_node t1 u1 <?> (cmp_cc_node, t2, u2))
| CC_reduction, CC_reduction -> 0
| CC_lit l1, CC_lit l2 -> cmp_lit l1 l2
| CC_injectivity (t1,t2), CC_injectivity (u1,u2) ->
| E_reduction, E_reduction -> 0
| E_lit l1, E_lit l2 -> cmp_lit l1 l2
| E_injectivity (t1,t2), E_injectivity (u1,u2) ->
CCOrd.(cmp_cc_node t1 u1 <?> (cmp_cc_node, t2, u2))
| CC_reduce_eq (t1, u1), CC_reduce_eq (t2,u2) ->
| E_reduce_eq (t1, u1), E_reduce_eq (t2,u2) ->
CCOrd.(cmp_cc_node t1 t2 <?> (cmp_cc_node, u1, u2))
| CC_congruence _, _ | CC_lit _, _ | CC_reduction, _
| CC_injectivity _, _ | CC_reduce_eq _, _
| E_custom r1, E_custom r2 ->
CCOrd.(ID.compare r1.name r2.name <?> (list cmp_exp, r1.args, r2.args))
| E_congruence _, _ | E_lit _, _ | E_reduction, _
| E_injectivity _, _ | E_reduce_eq _, _ | E_custom _, _
-> CCInt.compare (toint a)(toint b)
end
@ -237,14 +304,15 @@ let pp_term_top ~ids out t =
Fmt.fprintf out "(@[match %a@ (@[<hv>%a@])@])"
pp t print_map (ID.Map.to_seq m)
| Builtin (B_not t) -> Fmt.fprintf out "(@[<hv1>not@ %a@])" pp t
| Builtin (B_and (a,b)) ->
Fmt.fprintf out "(@[<hv1>and@ %a@ %a@])" pp a pp b
| Builtin (B_or (a,b)) ->
Fmt.fprintf out "(@[<hv1>or@ %a@ %a@])" pp a pp b
| Builtin (B_and l) ->
Fmt.fprintf out "(@[<hv1>and@ %a])" (Util.pp_list pp) l
| Builtin (B_or l) ->
Fmt.fprintf out "(@[<hv1>or@ %a@])" (Util.pp_list pp) l
| Builtin (B_imply (a,b)) ->
Fmt.fprintf out "(@[<hv1>=>@ %a@ %a@])" pp a pp b
Fmt.fprintf out "(@[<hv1>=>@ %a@ %a@])" (Util.pp_list pp) a pp b
| Builtin (B_eq (a,b)) ->
Fmt.fprintf out "(@[<hv1>=@ %a@ %a@])" pp a pp b
| Custom {view; tc} -> tc.tc_t_pp pp out view
and pp_id =
if ids then ID.pp else ID.pp_name
in
@ -263,12 +331,13 @@ let pp_lit out l =
let pp_cc_node out n = pp_term out n.n_term
let pp_cc_explanation out (e:cc_explanation) = match e with
| CC_reduction -> Fmt.string out "reduction"
| CC_lit lit -> pp_lit out lit
| CC_congruence (a,b) ->
let pp_explanation out (e:explanation) = match e with
| E_reduction -> Fmt.string out "reduction"
| E_lit lit -> pp_lit out lit
| E_congruence (a,b) ->
Format.fprintf out "(@[<hv1>congruence@ %a@ %a@])" pp_cc_node a pp_cc_node b
| CC_injectivity (a,b) ->
| E_injectivity (a,b) ->
Format.fprintf out "(@[<hv1>injectivity@ %a@ %a@])" pp_cc_node a pp_cc_node b
| CC_reduce_eq (t, u) ->
| E_reduce_eq (t, u) ->
Format.fprintf out "(@[<hv1>reduce_eq@ %a@ %a@])" pp_cc_node t pp_cc_node u
| E_custom {name; args; pp} -> pp out (name,args)

View file

@ -1,5 +1,4 @@
open CDCL
open Solver_types
type t = term
@ -62,9 +61,19 @@ let if_ st a b c = make st (Term_cell.if_ a b c)
let not_ st t = make st (Term_cell.not_ t)
let and_ st a b = make st (Term_cell.and_ a b)
let or_ st a b = make st (Term_cell.or_ a b)
let imply st a b = make st (Term_cell.imply a b)
let and_l st = function
| [] -> true_ st
| [t] -> t
| l -> make st (Term_cell.and_ l)
let or_l st = function
| [] -> false_ st
| [t] -> t
| l -> make st (Term_cell.or_ l)
let and_ st a b = and_l st [a;b]
let or_ st a b = and_l st [a;b]
let imply st a b = match a with [] -> b | _ -> make st (Term_cell.imply a b)
let eq st a b = make st (Term_cell.eq a b)
let neq st a b = not_ st (eq st a b)
let builtin st b = make st (Term_cell.builtin b)
@ -80,16 +89,6 @@ let abs t : t * bool = match t.term_cell with
| Builtin (B_not t) -> t, false
| _ -> t, true
let rec and_l st = function
| [] -> true_ st
| [t] -> t
| a :: l -> and_ st a (and_l st l)
let or_l st = function
| [] -> false_ st
| [t] -> t
| a :: l -> List.fold_left (or_ st) a l
let fold_map_builtin
(f:'a -> term -> 'a * term) (acc:'a) (b:t builtin): 'a * t builtin =
let fold_binary acc a b =
@ -101,17 +100,18 @@ let fold_map_builtin
| B_not t ->
let acc, t' = f acc t in
acc, B_not t'
| B_and (a,b) ->
let acc, a, b = fold_binary acc a b in
acc, B_and (a,b)
| B_or (a,b) ->
let acc, a, b = fold_binary acc a b in
acc, B_or (a, b)
| B_and l ->
let acc, l = CCList.fold_map f acc l in
acc, B_and l
| B_or l ->
let acc, l = CCList.fold_map f acc l in
acc, B_or l
| B_eq (a,b) ->
let acc, a, b = fold_binary acc a b in
acc, B_eq (a, b)
| B_imply (a,b) ->
let acc, a, b = fold_binary acc a b in
let acc, a = CCList.fold_map f acc a in
let acc, b = f acc b in
acc, B_imply (a, b)
let is_const t = match t.term_cell with
@ -124,10 +124,9 @@ let map_builtin f b =
let builtin_to_seq b yield = match b with
| B_not t -> yield t
| B_or (a,b)
| B_imply (a,b)
| B_or l | B_and l -> List.iter yield l
| B_imply (a,b) -> List.iter yield a; yield b
| B_eq (a,b) -> yield a; yield b
| B_and (a,b) -> yield a; yield b
module As_key = struct
type t = term
@ -150,6 +149,7 @@ let to_seq t yield =
aux t;
ID.Map.iter (fun _ rhs -> aux rhs) m
| Builtin b -> builtin_to_seq b aux
| Custom {view;tc} -> tc.tc_t_sub view aux
in
aux t
@ -181,12 +181,8 @@ let as_unif (t:term): unif_form = match t.term_cell with
Unif_cstor (c,cstor,a)
| _ -> Unif_none
let fpf = Format.fprintf
let pp = Solver_types.pp_term
let dummy : t = {
term_id= -1;
term_ty=Ty.prop;

View file

@ -25,7 +25,7 @@ val builtin : state -> t builtin -> t
val and_ : state -> t -> t -> t
val or_ : state -> t -> t -> t
val not_ : state -> t -> t
val imply : state -> t -> t -> t
val imply : state -> t list -> t -> t
val eq : state -> t -> t -> t
val neq : state -> t -> t -> t
val and_eager : state -> t -> t -> t (* evaluate left argument first *)

View file

@ -27,10 +27,11 @@ module Make_eq(A : ARG) = struct
in
Hash.combine3 8 (sub_hash u) hash_m
| Builtin (B_not a) -> Hash.combine2 20 (sub_hash a)
| Builtin (B_and (t1,t2)) -> Hash.combine3 21 (sub_hash t1) (sub_hash t2)
| Builtin (B_or (t1,t2)) -> Hash.combine3 22 (sub_hash t1) (sub_hash t2)
| Builtin (B_imply (t1,t2)) -> Hash.combine3 23 (sub_hash t1) (sub_hash t2)
| Builtin (B_and l) -> Hash.combine2 21 (Hash.list sub_hash l)
| Builtin (B_or l) -> Hash.combine2 22 (Hash.list sub_hash l)
| Builtin (B_imply (l1,t2)) -> Hash.combine3 23 (Hash.list sub_hash l1) (sub_hash t2)
| Builtin (B_eq (t1,t2)) -> Hash.combine3 24 (sub_hash t1) (sub_hash t2)
| Custom {view;tc} -> tc.tc_t_hash sub_hash view
(* equality that relies on physical equality of subterms *)
let equal (a:A.t term_cell) b : bool = match a, b with
@ -51,18 +52,21 @@ module Make_eq(A : ARG) = struct
| Builtin b1, Builtin b2 ->
begin match b1, b2 with
| B_not a1, B_not a2 -> sub_eq a1 a2
| B_and (a1,b1), B_and (a2,b2)
| B_or (a1,b1), B_or (a2,b2)
| B_eq (a1,b1), B_eq (a2,b2)
| B_imply (a1,b1), B_imply (a2,b2) -> sub_eq a1 a2 && sub_eq b1 b2
| B_and l1, B_and l2
| B_or l1, B_or l2 -> CCEqual.list sub_eq l1 l2
| B_eq (a1,b1), B_eq (a2,b2) -> sub_eq a1 a2 && sub_eq b1 b2
| B_imply (a1,b1), B_imply (a2,b2) -> CCEqual.list sub_eq a1 a2 && sub_eq b1 b2
| B_not _, _ | B_and _, _ | B_eq _, _
| B_or _, _ | B_imply _, _ -> false
end
| Custom r1, Custom r2 ->
r1.tc.tc_t_equal sub_eq r1.view r2.view
| True, _
| App_cst _, _
| If _, _
| Case _, _
| Builtin _, _
| Custom _, _
-> false
end[@@inline]
@ -90,24 +94,26 @@ let cstor_proj cstor i t =
app_cst p (IArray.singleton t)
let builtin b =
let mk_ x = Builtin x in
(* normalize a bit *)
let b = match b with
| B_eq (a,b) when a.term_id > b.term_id -> B_eq (b,a)
| B_and (a,b) when a.term_id > b.term_id -> B_and (b,a)
| B_or (a,b) when a.term_id > b.term_id -> B_or (b,a)
| _ -> b
in
Builtin b
begin match b with
| B_imply ([], x) -> x.term_cell
| B_eq (a,b) when a.term_id = b.term_id -> true_
| B_eq (a,b) when a.term_id > b.term_id -> mk_ @@ B_eq (b,a)
| _ -> mk_ b
end
let not_ t = match t.term_cell with
| Builtin (B_not t') -> t'.term_cell
| _ -> builtin (B_not t)
let and_ a b = builtin (B_and (a,b))
let or_ a b = builtin (B_or (a,b))
let and_ l = builtin (B_and l)
let or_ l = builtin (B_or l)
let imply a b = builtin (B_imply (a,b))
let eq a b = builtin (B_eq (a,b))
let custom ~tc view = Custom {view;tc}
(* type of an application *)
let rec app_ty_ ty l : Ty.t = match Ty.view ty, l with
| _, [] -> ty
@ -132,6 +138,7 @@ let ty (t:t): Ty.t = match t with
let _, rhs = ID.Map.choose m in
rhs.term_ty
| Builtin _ -> Ty.prop
| Custom {view;tc} -> tc.tc_t_ty (fun t -> t.term_ty) view
module Tbl = CCHashtbl.Make(struct
type t = term term_cell

View file

@ -15,11 +15,12 @@ val cstor_proj : data_cstor -> int -> term -> t
val case : term -> term ID.Map.t -> t
val if_ : term -> term -> term -> t
val builtin : term builtin -> t
val and_ : term -> term -> t
val or_ : term -> term -> t
val and_ : term list -> t
val or_ : term list -> t
val not_ : term -> t
val imply : term -> term -> t
val imply : term list -> term -> t
val eq : term -> term -> t
val custom : tc:term_view_tc -> term term_view_custom -> t
val ty : t -> Ty.t
(** Compute the type of this term cell. Not totally free *)