diff --git a/src/core/Solver_intf.ml b/src/core/Solver_intf.ml index 76db6216..3a6678a2 100644 --- a/src/core/Solver_intf.ml +++ b/src/core/Solver_intf.ml @@ -87,6 +87,7 @@ module type S = sig (** Add the list of clauses to the current set of assumptions. Modifies the sat solver state in place. *) + (* TODO: provide a local, backtrackable version *) val add_clause : t -> clause -> unit (** Lower level addition of clauses *) diff --git a/src/smt/Congruence_closure.ml b/src/smt/Congruence_closure.ml index 0dcde83f..af6d3761 100644 --- a/src/smt/Congruence_closure.ml +++ b/src/smt/Congruence_closure.ml @@ -3,7 +3,7 @@ open CDCL open Solver_types type node = Equiv_class.t -type repr = Equiv_class.repr +type repr = Equiv_class.t (** A signature is a shallow term shape where immediate subterms are representative *) @@ -14,12 +14,12 @@ end module Sig_tbl = CCHashtbl.Make(Signature) -type merge_op = node * node * cc_explanation +type merge_op = node * node * explanation (* a merge operation to perform *) type actions = - | Propagate of Lit.t * cc_explanation list - | Split of Lit.t list * cc_explanation list + | Propagate of Lit.t * explanation list + | Split of Lit.t list * explanation list | Merge of node * node (* merge these two classes *) type t = { @@ -38,7 +38,7 @@ type t = { (* register a function to be called when we backtrack *) at_lvl_0: unit -> bool; (* currently at level 0? *) - on_merge: (repr -> repr -> cc_explanation -> unit) list; + on_merge: (repr -> repr -> explanation -> unit) list; (* callbacks to call when we merge classes *) pending: node Vec.t; (* nodes to check, maybe their new signature is in {!signatures_tbl} *) @@ -59,11 +59,6 @@ type t = { several times. See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *) -module CC_expl_set = CCSet.Make(struct - type t = cc_explanation - let compare = Solver_types.cmp_cc_expl - end) - let[@inline] is_root_ (n:node) : bool = n.n_root == n let[@inline] size_ (r:repr) = @@ -78,7 +73,7 @@ let[@inline] mem (cc:t) (t:term): bool = (* find representative, recursively, and perform path compression *) let rec find_rec cc (n:node) : repr = if n==n.n_root then ( - Equiv_class.unsafe_repr_of_node n + n ) else ( let old_root = n.n_root in let root = find_rec cc old_root in @@ -104,19 +99,20 @@ let[@inline] get_ cc (t:term) : node = (* non-recursive, inlinable function for [find] *) let[@inline] find st (n:node) : repr = - if n == n.n_root - then (Equiv_class.unsafe_repr_of_node n) - else find_rec st n + if n == n.n_root then n else find_rec st n let[@inline] find_tn cc (t:term) : repr = get_ cc t |> find cc -let[@inline] find_tt cc (t:term) : term = find_tn cc t |> Equiv_class.Repr.term +let[@inline] find_tt cc (t:term) : term = find_tn cc t |> Equiv_class.term let[@inline] same_class cc (n1:node)(n2:node): bool = - Equiv_class.Repr.equal (find cc n1) (find cc n2) + Equiv_class.equal (find cc n1) (find cc n2) + +let[@inline] same_class_t cc (t1:term)(t2:term): bool = + Equiv_class.equal (find_tn cc t1) (find_tn cc t2) (* compute signature *) let signature cc (t:term): node term_cell option = - let find = (find_tn cc :> term -> node) in + let find = find_tn cc in begin match Term.cell t with | True | Builtin _ -> None @@ -124,6 +120,8 @@ let signature cc (t:term): node term_cell option = | App_cst (f, a) -> App_cst (f, IArray.map find a) |> CCOpt.return | If (a,b,c) -> If (find a, get_ cc b, get_ cc c) |> CCOpt.return | Case (t, m) -> Case (find t, ID.Map.map (get_ cc) m) |> CCOpt.return + | Custom {view;tc} -> + Custom {tc; view=tc.tc_t_subst find view} |> CCOpt.return end (* find whether the given (parent) term corresponds to some signature @@ -151,7 +149,7 @@ let add_signature cc (t:term) (r:repr): unit = match signature cc t with ); Sig_tbl.add cc.signatures_tbl s r; | Some r' -> - assert (Equiv_class.Repr.equal r r'); + assert (Equiv_class.equal r r'); end let is_done (cc:t): bool = @@ -165,24 +163,24 @@ let push_pending cc t : unit = let push_combine cc t u e : unit = Log.debugf 5 (fun k->k "(@[push_combine@ %a@ %a@ expl: %a@])" - Equiv_class.pp t Equiv_class.pp u pp_cc_explanation e); + Equiv_class.pp t Equiv_class.pp u Explanation.pp e); Vec.push cc.combine (t,u,e) -let push_split cc (lits:lit list) (expl:cc_explanation list): unit = +let push_split cc (lits:lit list) (expl:explanation list): unit = Log.debugf 5 (fun k->k "(@[push_split@ (@[%a@])@ expl: (@[%a@])@])" - (Util.pp_list Lit.pp) lits (Util.pp_list pp_cc_explanation) expl); + (Util.pp_list Lit.pp) lits (Util.pp_list Explanation.pp) expl); let l = Split (lits, expl) in cc.actions <- l :: cc.actions -let push_propagation cc (lit:lit) (expl:cc_explanation list): unit = +let push_propagation cc (lit:lit) (expl:explanation list): unit = Log.debugf 5 (fun k->k "(@[push_propagate@ %a@ expl: (@[%a@])@])" - Lit.pp lit (Util.pp_list pp_cc_explanation) expl); + Lit.pp lit (Util.pp_list Explanation.pp) expl); let l = Propagate (lit,expl) in cc.actions <- l :: cc.actions -let[@inline] union cc (a:node) (b:node) (e:cc_explanation): unit = +let[@inline] union cc (a:node) (b:node) (e:explanation): unit = if not (same_class cc a b) then ( push_combine cc a b e; (* start by merging [a=b] *) ) @@ -196,11 +194,11 @@ let rec reroot_expl cc (n:node): unit = cc.on_backtrack (fun () -> n.n_expl <- old_expl); ); begin match old_expl with - | None -> () (* already root *) - | Some (u, e_n_u) -> + | E_none -> () (* already root *) + | E_some {next=u; expl=e_n_u} -> reroot_expl cc u; - u.n_expl <- Some (n, e_n_u); - n.n_expl <- None; + u.n_expl <- E_some {next=n; expl=e_n_u}; + n.n_expl <- E_none; end (* TODO: @@ -208,19 +206,18 @@ let rec reroot_expl cc (n:node): unit = - also, obtain merges of CC via callbacks / [pop_merges] afterwards? *) -exception Exn_unsat of cc_explanation list +exception Exn_unsat of explanation Bag.t -let unsat (e:cc_explanation list): _ = raise (Exn_unsat e) +let unsat (e:explanation Bag.t): _ = raise (Exn_unsat e) type result = | Sat of actions list - | Unsat of cc_explanation list + | Unsat of explanation Bag.t (* list of direct explanations to the conflict. *) let[@inline] all_classes cc : repr Sequence.t = Term.Tbl.values cc.tbl |> Sequence.filter is_root_ - |> Equiv_class.unsafe_repr_seq_of_seq (* main CC algo: add terms from [pending] to the signature table, check for collisions *) @@ -236,7 +233,7 @@ let rec update_pending (cc:t): result = add_signature cc n.n_term (find cc n) | Some u -> (* must combine [t] with [r] *) - push_combine cc n (u:>node) (CC_congruence (n,(u:>node))) + push_combine cc n u(E_congruence (n,u)) end; (* FIXME: when to actually evaluate? eval_pending cc; @@ -257,8 +254,8 @@ and update_combine cc = let a, b, e_ab = Vec.pop_last cc.combine in let ra = find cc a in let rb = find cc b in - if not (Equiv_class.Repr.equal ra rb) then ( - assert (is_root_ (ra:>node)); + if not (Equiv_class.equal ra rb) then ( + assert (is_root_ ra); assert (is_root_ (rb:>node)); (* We will merge [r_from] into [r_into]. we try to ensure that [size ra <= size rb] in general, unless @@ -296,11 +293,11 @@ and update_combine cc = (* update explanations (a -> b), arbitrarily *) begin reroot_expl cc a; - assert (a.n_expl = None); + assert (a.n_expl = E_none); if not (cc.at_lvl_0 ()) then ( - cc.on_backtrack (fun () -> a.n_expl <- None); + cc.on_backtrack (fun () -> a.n_expl <- E_none); ); - a.n_expl <- Some (b, e_ab); + a.n_expl <- E_some {next=b; expl=e_ab}; end; (* notify listeners of the merge *) notify_merge cc r_from ~into:r_into e_ab; @@ -312,7 +309,7 @@ and update_combine cc = (* Checks if [ra] and [~into] have compatible normal forms and can be merged w.r.t. the theories. Side effect: also pushes sub-tasks *) -and notify_merge cc (ra:repr) ~into:(rb:repr) (e:cc_explanation): unit = +and notify_merge cc (ra:repr) ~into:(rb:repr) (e:explanation): unit = assert (is_root_ (ra:>node)); assert (is_root_ (rb:>node)); List.iter @@ -366,6 +363,7 @@ and add_new_term cc (t:term) : node = add_sub_t c | Case (u, _) -> add_sub_t u | Builtin b -> Term.builtin_to_seq b add_sub_t + | Custom {view;tc} -> tc.tc_t_sub view add_sub_t end; (* remove term when we backtrack *) if not (cc.at_lvl_0 ()) then ( @@ -399,7 +397,7 @@ let assert_lit cc lit : unit = match Lit.view lit with (* equate t and true/false *) let rhs = if sign then true_ cc else false_ cc in let n = add cc t in - push_combine cc n rhs (CC_lit lit); + push_combine cc n rhs (E_lit lit); () let create ?(size=2048) ~on_backtrack ~at_lvl_0 ~on_merge (tst:Term.state) : t = @@ -413,7 +411,7 @@ let create ?(size=2048) ~on_backtrack ~at_lvl_0 ~on_merge (tst:Term.state) : t = on_backtrack; at_lvl_0; pending=Vec.make_empty Equiv_class.dummy; - combine= Vec.make_empty (nd,nd,CC_reduce_eq(nd,nd)); + combine= Vec.make_empty (nd,nd,E_reduce_eq(nd,nd)); actions=[]; ps_lits=Lit.Set.empty; ps_queue=Vec.make_empty (nd,nd); @@ -426,8 +424,8 @@ let create ?(size=2048) ~on_backtrack ~at_lvl_0 ~on_merge (tst:Term.state) : t = (* distance from [t] to its root in the proof forest *) let[@inline][@unroll 2] rec distance_to_root (n:node): int = match n.n_expl with - | None -> 0 - | Some (t', _) -> 1 + distance_to_root t' + | E_none -> 0 + | E_some {next=t'; _} -> 1 + distance_to_root t' (* find the closest common ancestor of [a] and [b] in the proof forest *) let find_common_ancestor (a:node) (b:node) : node = @@ -437,8 +435,8 @@ let find_common_ancestor (a:node) (b:node) : node = let rec drop_ n t = if n=0 then t else match t.n_expl with - | None -> assert false - | Some (t', _) -> drop_ (n-1) t' + | E_none -> assert false + | E_some {next=t'; _} -> drop_ (n-1) t' in (* reduce to the problem where [a] and [b] have the same distance to root *) let a, b = @@ -450,18 +448,13 @@ let find_common_ancestor (a:node) (b:node) : node = let rec aux_same_dist a b = if a==b then a else match a.n_expl, b.n_expl with - | None, _ | _, None -> assert false - | Some (a', _), Some (b', _) -> aux_same_dist a' b' + | E_none, _ | _, E_none -> assert false + | E_some {next=a'; _}, E_some {next=b'; _} -> aux_same_dist a' b' in aux_same_dist a b let[@inline] ps_add_obligation (cc:t) a b = Vec.push cc.ps_queue (a,b) let[@inline] ps_add_lit ps l = ps.ps_lits <- Lit.Set.add l ps.ps_lits -let[@inline] ps_add_expl ps e = match e with - | CC_lit lit -> ps_add_lit ps lit - | CC_reduce_eq _ | CC_congruence _ - | CC_injectivity _ | CC_reduction - -> () and ps_add_obligation_t cc (t1:term) (t2:term) = let n1 = get_ cc t1 in @@ -473,41 +466,38 @@ let ps_clear (cc:t) = Vec.clear cc.ps_queue; () -let decompose_explain cc (e:cc_explanation): unit = - Log.debugf 5 (fun k->k "(@[decompose_expl@ %a@])" pp_cc_explanation e); - ps_add_expl cc e; +let rec decompose_explain cc (e:explanation): unit = + Log.debugf 5 (fun k->k "(@[decompose_expl@ %a@])" Explanation.pp e); begin match e with - | CC_reduction - | CC_lit _ -> () - | CC_reduce_eq (a, b) -> + | E_reduction -> () + | E_lit lit -> ps_add_lit cc lit + | E_custom {args;_} -> + (* decompose sub-expls *) + List.iter (decompose_explain cc) args + | E_reduce_eq (a, b) -> ps_add_obligation cc a b; - | CC_injectivity (t1,t2) - (* FIXME: should this be different from CC_congruence? just explain why t1==t2? *) - | CC_congruence (t1,t2) -> + | E_injectivity (t1,t2) -> + (* arguments of [t1], [t2] are equal by injectivity, so we + just need to explain why [t1=t2] *) + ps_add_obligation cc t1 t2 + | E_congruence (t1,t2) -> + (* [t1] and [t2] must be applications of the same symbol to + arguments that are pairwise equal *) begin match t1.n_term.term_cell, t2.n_term.term_cell with - | True, _ -> assert false (* no congruence here *) | App_cst (f1, a1), App_cst (f2, a2) -> assert (Cst.equal f1 f2); assert (IArray.length a1 = IArray.length a2); IArray.iter2 (ps_add_obligation_t cc) a1 a2 - | Case (_t1, _m1), Case (_t2, _m2) -> - assert false - (* TODO: this should never happen - ps_add_obligation ps t1 t2; - ID.Map.iter - (fun id rhs1 -> - let rhs2 = ID.Map.find id m2 in - ps_add_obligation ps rhs1 rhs2) - m1; - *) - | If (a1,b1,c1), If (a2,b2,c2) -> - ps_add_obligation_t cc a1 a2; - ps_add_obligation_t cc b1 b2; - ps_add_obligation_t cc c1 c2; - | Builtin _, _ -> assert false + | Custom r1, Custom r2 -> + (* ask the theory to explain why [r1 = r2] *) + let l = r1.tc.tc_t_explain (same_class_t cc) r1.view r2.view in + List.iter (fun (t,u) -> ps_add_obligation_t cc t u) l + | If _, _ + | Builtin _, _ | App_cst _, _ | Case _, _ - | If _, _ + | True, _ + | Custom _, _ -> assert false end end @@ -517,8 +507,8 @@ let decompose_explain cc (e:cc_explanation): unit = let rec explain_along_path ps (a:node) (parent_a:node) : unit = if a!=parent_a then ( match a.n_expl with - | None -> assert false - | Some (next_a, e_a_b) -> + | E_none -> assert false + | E_some {next=next_a; expl=e_a_b} -> decompose_explain ps e_a_b; (* now prove [next_a = parent_a] *) explain_along_path ps next_a parent_a @@ -530,17 +520,17 @@ let explain_loop (cc : t) : Lit.Set.t = let a, b = Vec.pop_last cc.ps_queue in Log.debugf 5 (fun k->k "(@[explain_loop at@ %a@ %a@])" Equiv_class.pp a Equiv_class.pp b); - assert (Equiv_class.Repr.equal (find cc a) (find cc b)); + assert (Equiv_class.equal (find cc a) (find cc b)); let c = find_common_ancestor a b in explain_along_path cc a c; explain_along_path cc b c; done; cc.ps_lits -let explain_unfold cc (l:cc_explanation list): Lit.Set.t = +let explain_unfold cc (l:explanation list): Lit.Set.t = Log.debugf 5 (fun k->k "(@[explain_confict@ (@[%a@])@])" - (Util.pp_list pp_cc_explanation) l); + (Util.pp_list Explanation.pp) l); ps_clear cc; List.iter (decompose_explain cc) l; explain_loop cc diff --git a/src/smt/Congruence_closure.mli b/src/smt/Congruence_closure.mli index fc864895..90a4f736 100644 --- a/src/smt/Congruence_closure.mli +++ b/src/smt/Congruence_closure.mli @@ -9,14 +9,14 @@ type t type node = Equiv_class.t (** Node in the congruence closure *) -type repr = Equiv_class.repr +type repr = Equiv_class.t (** Node that is currently a representative *) val create : ?size:int -> on_backtrack:((unit -> unit) -> unit) -> at_lvl_0:(unit -> bool) -> - on_merge:(repr -> repr -> cc_explanation -> unit) list -> + on_merge:(repr -> repr -> explanation -> unit) list -> Term.state -> t (** Create a new congruence closure. @@ -30,7 +30,7 @@ val find : t -> node -> repr val same_class : t -> node -> node -> bool (** Are these two classes the same in the current CC? *) -val union : t -> node -> node -> cc_explanation -> unit +val union : t -> node -> node -> explanation -> unit (** Merge the two equivalence classes. Will be undone on backtracking. *) val assert_lit : t -> Lit.t -> unit @@ -48,19 +48,19 @@ val add_seq : t -> term Sequence.t -> unit (** Add a sequence of terms to the congruence closure *) type actions = - | Propagate of Lit.t * cc_explanation list - | Split of Lit.t list * cc_explanation list + | Propagate of Lit.t * explanation list + | Split of Lit.t list * explanation list | Merge of node * node (* merge these two classes *) type result = | Sat of actions list - | Unsat of cc_explanation list + | Unsat of explanation Bag.t (* list of direct explanations to the conflict. *) val check : t -> result val final_check : t -> result -val explain_unfold: t -> cc_explanation list -> Lit.Set.t +val explain_unfold: t -> explanation list -> Lit.Set.t (** Unfold those explanations into a complete set of literals implying them *) diff --git a/src/smt/Equiv_class.ml b/src/smt/Equiv_class.ml index 74fa1f7a..d2e7a3c4 100644 --- a/src/smt/Equiv_class.ml +++ b/src/smt/Equiv_class.ml @@ -1,9 +1,7 @@ -open CDCL open Solver_types type t = cc_node -type repr = t type payload = cc_node_payload let field_expanded = Node_bits.mk_field () @@ -11,6 +9,7 @@ let field_has_expansion_lit = Node_bits.mk_field () let field_is_lit = Node_bits.mk_field () let field_is_split = Node_bits.mk_field () let field_add_level_0 = Node_bits.mk_field() +let field_is_active = Node_bits.mk_field() let () = Node_bits.freeze() let[@inline] equal (n1:t) n2 = n1==n2 @@ -19,19 +18,6 @@ let[@inline] term n = n.n_term let[@inline] payload n = n.n_payload let[@inline] pp out n = Term.pp out n.n_term -module Repr = struct - type node = t - type t = repr - let equal = equal - let hash = hash - let term = term - let payload = payload - let pp = pp - - let[@inline] parents r = r.n_parents - let[@inline] class_ r = r.n_class -end - let make (t:term) : t = let rec n = { n_term=t; @@ -39,7 +25,7 @@ let make (t:term) : t = n_class=Bag.empty; n_parents=Bag.empty; n_root=n; - n_expl=None; + n_expl=E_none; n_payload=[]; } in (* set [class(t) = {t}] *) @@ -82,5 +68,3 @@ let payload_pred ~f:p n = let dummy = make Term.dummy -let[@inline] unsafe_repr_of_node n = n -let[@inline] unsafe_repr_seq_of_seq s = s diff --git a/src/smt/Equiv_class.mli b/src/smt/Equiv_class.mli index 6ec193ad..fe8e2513 100644 --- a/src/smt/Equiv_class.mli +++ b/src/smt/Equiv_class.mli @@ -21,7 +21,6 @@ open Solver_types *) type t = cc_node -type repr = private t type payload = cc_node_payload val field_expanded : Node_bits.field @@ -42,6 +41,10 @@ val field_add_level_0 : Node_bits.field (** Is the corresponding term to be re-added upon backtracking, down to level 0? *) +val field_is_active : Node_bits.field +(** The term is needed for evaluation. We must try to evaluate it + or to find a value for it using the theory *) + (** {2 basics} *) val term : t -> term @@ -50,20 +53,6 @@ val hash : t -> int val pp : t Fmt.printer val payload : t -> payload list -module Repr : sig - type node = t - type t = repr - - val term : t -> term - val equal : t -> t -> bool - val hash : t -> int - val pp : t Fmt.printer - val payload : t -> payload list - - val parents : t -> node Bag.t - val class_ : t -> node Bag.t -end - (** {2 Helpers} *) val make : term -> t @@ -80,6 +69,4 @@ val set_payload : ?can_erase:(payload -> bool) -> t -> payload -> unit (**/**) val dummy : t -val unsafe_repr_of_node : t -> repr -val unsafe_repr_seq_of_seq : t Sequence.t -> repr Sequence.t (**/**) diff --git a/src/smt/Explanation.ml b/src/smt/Explanation.ml new file mode 100644 index 00000000..f344b79a --- /dev/null +++ b/src/smt/Explanation.ml @@ -0,0 +1,16 @@ + +open CDCL +open Solver_types + +type t = explanation + +let compare = cmp_exp +let equal a b = cmp_exp a b = 0 + +let pp = pp_explanation + +module Set = CCSet.Make(struct + type t = explanation + let compare = compare + end) + diff --git a/src/smt/Solver_types.ml b/src/smt/Solver_types.ml index 889d2149..99953868 100644 --- a/src/smt/Solver_types.ml +++ b/src/smt/Solver_types.ml @@ -23,13 +23,66 @@ and 'a term_cell = | If of 'a * 'a * 'a | Case of 'a * 'a ID.Map.t (* check head constructor *) | Builtin of 'a builtin + | Custom of { + view: 'a term_view_custom; + tc: term_view_tc; + } and 'a builtin = | B_not of 'a | B_eq of 'a * 'a - | B_and of 'a * 'a - | B_or of 'a * 'a - | B_imply of 'a * 'a + | B_and of 'a list + | B_or of 'a list + | B_imply of 'a list * 'a + +(** Methods on the custom term view whose leaves are ['a]. + Terms must be comparable, hashable, printable, and provide + some additional theory handles. + + - [tc_t_sub] must return all immediate subterms (all ['a] contained in the term) + + - [tc_t_subst] must use the function to replace all subterms (all the ['a] + returned by [tc_t_sub]) by ['b] + + - [tc_t_relevant] must return a subset of [tc_t_sub] (possibly the same set). + The terms it returns will be activated and evaluated whenever possible. + Terms in [tc_t_sub t \ tc_t_relevant t] are considered for + congruence but not for evaluation. + + - If [t1] and [t2] satisfy [tc_t_is_semantic] and have the same type, + then [tc_t_solve t1 t2] must succeed by returning some {!solve_result}. + + - if [tc_t_equal eq a b = true], then [tc_t_explain eq a b] must + return all the pairs of equal subterms that are sufficient + for [a] and [b] to be equal. +*) +and term_view_tc = { + tc_t_pp : 'a. 'a Fmt.printer -> 'a term_view_custom Fmt.printer; + tc_t_equal : 'a. 'a CCEqual.t -> 'a term_view_custom CCEqual.t; + tc_t_hash : 'a. 'a Hash.t -> 'a term_view_custom Hash.t; + tc_t_ty : 'a. ('a -> ty) -> 'a term_view_custom -> ty; + tc_t_is_semantic : cc_node term_view_custom -> bool; (* is this a semantic term? semantic terms must be solvable *) + tc_t_solve: cc_node term_view_custom -> cc_node term_view_custom -> solve_result; (* solve an equation between classes *) + tc_t_sub : 'a. 'a term_view_custom -> 'a Sequence.t; (* iter on immediate subterms *) + tc_t_relevant : 'a. 'a term_view_custom -> 'a Sequence.t; (* iter on relevant immediate subterms *) + tc_t_subst : 'a 'b. ('a -> 'b) -> 'a term_view_custom -> 'b term_view_custom; (* substitute immediate subterms and canonize *) + tc_t_explain : 'a. 'a CCEqual.t -> 'a term_view_custom -> 'a term_view_custom -> ('a * 'a) list; + (* explain why the two views are equal *) +} + +(** Custom term view for theories *) +and 'a term_view_custom = .. + +(** The result of a call to {!solve}. *) +and solve_result = + | Solve_ok of { + subst: (cc_node * term) list; (** binding leaves to other terms *) + } (** Success, the two terms being equal is equivalent + to the given substitution *) + | Solve_fail of { + expl: explanation; + } (** Failure, because of the given explanation. + The two terms cannot be equal *) (** A node of the congruence closure. An equivalence class is represented by its "root" element, @@ -43,21 +96,32 @@ and cc_node = { mutable n_class: cc_node Bag.t; (* terms in the same equiv class *) mutable n_parents: cc_node Bag.t; (* parent terms of the whole equiv class *) mutable n_root: cc_node; (* representative of congruence class (itself if a representative) *) - mutable n_expl: (cc_node * cc_explanation) option; (* the rooted forest for explanations *) + mutable n_expl: explanation_forest_link; (* the rooted forest for explanations *) mutable n_payload: cc_node_payload list; (* list of theory payloads *) } (** Theory-extensible payloads *) and cc_node_payload = .. +and explanation_forest_link = + | E_none + | E_some of { + next: cc_node; + expl: explanation; + } + (* atomic explanation in the congruence closure *) -and cc_explanation = - | CC_reduction (* by pure reduction, tautologically equal *) - | CC_lit of lit (* because of this literal *) - | CC_congruence of cc_node * cc_node (* same shape *) - | CC_injectivity of cc_node * cc_node (* arguments of those constructors *) - | CC_reduce_eq of cc_node * cc_node (* reduce because those are equal *) -(* TODO: theory expl *) +and explanation = + | E_reduction (* by pure reduction, tautologically equal *) + | E_lit of lit (* because of this literal *) + | E_congruence of cc_node * cc_node (* these terms are congruent *) + | E_injectivity of cc_node * cc_node (* injective function *) + | E_reduce_eq of cc_node * cc_node (* reduce because those are equal by reduction *) + | E_custom of { + name: ID.t; (* name of the rule *) + args: explanation list; (* sub-explanations *) + pp: (ID.t * explanation list) Fmt.printer; + } (** Custom explanation, typically for theories *) (* boolean literal *) and lit = { @@ -85,7 +149,7 @@ and cst_kind = (* what kind of constant is that? *) and cst_defined_info = - | Cst_recursive + | Cst_recursive (* TODO: the set of Horn rules compiled from the def *) | Cst_non_recursive (* this is a disjunction of sufficient conditions for the existence of @@ -171,23 +235,26 @@ let hash_lit a = let cmp_cc_node a b = term_cmp_ a.n_term b.n_term -let cmp_cc_expl a b = +let rec cmp_exp a b = let toint = function - | CC_congruence _ -> 0 | CC_lit _ -> 1 - | CC_reduction -> 2 | CC_injectivity _ -> 3 - | CC_reduce_eq _ -> 5 + | E_congruence _ -> 0 | E_lit _ -> 1 + | E_reduction -> 2 | E_injectivity _ -> 3 + | E_reduce_eq _ -> 5 + | E_custom _ -> 6 in begin match a, b with - | CC_congruence (t1,t2), CC_congruence (u1,u2) -> + | E_congruence (t1,t2), E_congruence (u1,u2) -> CCOrd.(cmp_cc_node t1 u1 (cmp_cc_node, t2, u2)) - | CC_reduction, CC_reduction -> 0 - | CC_lit l1, CC_lit l2 -> cmp_lit l1 l2 - | CC_injectivity (t1,t2), CC_injectivity (u1,u2) -> + | E_reduction, E_reduction -> 0 + | E_lit l1, E_lit l2 -> cmp_lit l1 l2 + | E_injectivity (t1,t2), E_injectivity (u1,u2) -> CCOrd.(cmp_cc_node t1 u1 (cmp_cc_node, t2, u2)) - | CC_reduce_eq (t1, u1), CC_reduce_eq (t2,u2) -> + | E_reduce_eq (t1, u1), E_reduce_eq (t2,u2) -> CCOrd.(cmp_cc_node t1 t2 (cmp_cc_node, u1, u2)) - | CC_congruence _, _ | CC_lit _, _ | CC_reduction, _ - | CC_injectivity _, _ | CC_reduce_eq _, _ + | E_custom r1, E_custom r2 -> + CCOrd.(ID.compare r1.name r2.name (list cmp_exp, r1.args, r2.args)) + | E_congruence _, _ | E_lit _, _ | E_reduction, _ + | E_injectivity _, _ | E_reduce_eq _, _ | E_custom _, _ -> CCInt.compare (toint a)(toint b) end @@ -237,14 +304,15 @@ let pp_term_top ~ids out t = Fmt.fprintf out "(@[match %a@ (@[%a@])@])" pp t print_map (ID.Map.to_seq m) | Builtin (B_not t) -> Fmt.fprintf out "(@[not@ %a@])" pp t - | Builtin (B_and (a,b)) -> - Fmt.fprintf out "(@[and@ %a@ %a@])" pp a pp b - | Builtin (B_or (a,b)) -> - Fmt.fprintf out "(@[or@ %a@ %a@])" pp a pp b + | Builtin (B_and l) -> + Fmt.fprintf out "(@[and@ %a])" (Util.pp_list pp) l + | Builtin (B_or l) -> + Fmt.fprintf out "(@[or@ %a@])" (Util.pp_list pp) l | Builtin (B_imply (a,b)) -> - Fmt.fprintf out "(@[=>@ %a@ %a@])" pp a pp b + Fmt.fprintf out "(@[=>@ %a@ %a@])" (Util.pp_list pp) a pp b | Builtin (B_eq (a,b)) -> Fmt.fprintf out "(@[=@ %a@ %a@])" pp a pp b + | Custom {view; tc} -> tc.tc_t_pp pp out view and pp_id = if ids then ID.pp else ID.pp_name in @@ -263,12 +331,13 @@ let pp_lit out l = let pp_cc_node out n = pp_term out n.n_term -let pp_cc_explanation out (e:cc_explanation) = match e with - | CC_reduction -> Fmt.string out "reduction" - | CC_lit lit -> pp_lit out lit - | CC_congruence (a,b) -> +let pp_explanation out (e:explanation) = match e with + | E_reduction -> Fmt.string out "reduction" + | E_lit lit -> pp_lit out lit + | E_congruence (a,b) -> Format.fprintf out "(@[congruence@ %a@ %a@])" pp_cc_node a pp_cc_node b - | CC_injectivity (a,b) -> + | E_injectivity (a,b) -> Format.fprintf out "(@[injectivity@ %a@ %a@])" pp_cc_node a pp_cc_node b - | CC_reduce_eq (t, u) -> + | E_reduce_eq (t, u) -> Format.fprintf out "(@[reduce_eq@ %a@ %a@])" pp_cc_node t pp_cc_node u + | E_custom {name; args; pp} -> pp out (name,args) diff --git a/src/smt/Term.ml b/src/smt/Term.ml index 01932453..d5b776dd 100644 --- a/src/smt/Term.ml +++ b/src/smt/Term.ml @@ -1,5 +1,4 @@ -open CDCL open Solver_types type t = term @@ -62,9 +61,19 @@ let if_ st a b c = make st (Term_cell.if_ a b c) let not_ st t = make st (Term_cell.not_ t) -let and_ st a b = make st (Term_cell.and_ a b) -let or_ st a b = make st (Term_cell.or_ a b) -let imply st a b = make st (Term_cell.imply a b) +let and_l st = function + | [] -> true_ st + | [t] -> t + | l -> make st (Term_cell.and_ l) + +let or_l st = function + | [] -> false_ st + | [t] -> t + | l -> make st (Term_cell.or_ l) + +let and_ st a b = and_l st [a;b] +let or_ st a b = and_l st [a;b] +let imply st a b = match a with [] -> b | _ -> make st (Term_cell.imply a b) let eq st a b = make st (Term_cell.eq a b) let neq st a b = not_ st (eq st a b) let builtin st b = make st (Term_cell.builtin b) @@ -80,16 +89,6 @@ let abs t : t * bool = match t.term_cell with | Builtin (B_not t) -> t, false | _ -> t, true -let rec and_l st = function - | [] -> true_ st - | [t] -> t - | a :: l -> and_ st a (and_l st l) - -let or_l st = function - | [] -> false_ st - | [t] -> t - | a :: l -> List.fold_left (or_ st) a l - let fold_map_builtin (f:'a -> term -> 'a * term) (acc:'a) (b:t builtin): 'a * t builtin = let fold_binary acc a b = @@ -101,17 +100,18 @@ let fold_map_builtin | B_not t -> let acc, t' = f acc t in acc, B_not t' - | B_and (a,b) -> - let acc, a, b = fold_binary acc a b in - acc, B_and (a,b) - | B_or (a,b) -> - let acc, a, b = fold_binary acc a b in - acc, B_or (a, b) + | B_and l -> + let acc, l = CCList.fold_map f acc l in + acc, B_and l + | B_or l -> + let acc, l = CCList.fold_map f acc l in + acc, B_or l | B_eq (a,b) -> let acc, a, b = fold_binary acc a b in acc, B_eq (a, b) | B_imply (a,b) -> - let acc, a, b = fold_binary acc a b in + let acc, a = CCList.fold_map f acc a in + let acc, b = f acc b in acc, B_imply (a, b) let is_const t = match t.term_cell with @@ -124,10 +124,9 @@ let map_builtin f b = let builtin_to_seq b yield = match b with | B_not t -> yield t - | B_or (a,b) - | B_imply (a,b) + | B_or l | B_and l -> List.iter yield l + | B_imply (a,b) -> List.iter yield a; yield b | B_eq (a,b) -> yield a; yield b - | B_and (a,b) -> yield a; yield b module As_key = struct type t = term @@ -150,6 +149,7 @@ let to_seq t yield = aux t; ID.Map.iter (fun _ rhs -> aux rhs) m | Builtin b -> builtin_to_seq b aux + | Custom {view;tc} -> tc.tc_t_sub view aux in aux t @@ -181,12 +181,8 @@ let as_unif (t:term): unif_form = match t.term_cell with Unif_cstor (c,cstor,a) | _ -> Unif_none -let fpf = Format.fprintf - let pp = Solver_types.pp_term - - let dummy : t = { term_id= -1; term_ty=Ty.prop; diff --git a/src/smt/Term.mli b/src/smt/Term.mli index 481d435b..de3d99eb 100644 --- a/src/smt/Term.mli +++ b/src/smt/Term.mli @@ -25,7 +25,7 @@ val builtin : state -> t builtin -> t val and_ : state -> t -> t -> t val or_ : state -> t -> t -> t val not_ : state -> t -> t -val imply : state -> t -> t -> t +val imply : state -> t list -> t -> t val eq : state -> t -> t -> t val neq : state -> t -> t -> t val and_eager : state -> t -> t -> t (* evaluate left argument first *) diff --git a/src/smt/Term_cell.ml b/src/smt/Term_cell.ml index d07fadde..3e25b5c1 100644 --- a/src/smt/Term_cell.ml +++ b/src/smt/Term_cell.ml @@ -27,10 +27,11 @@ module Make_eq(A : ARG) = struct in Hash.combine3 8 (sub_hash u) hash_m | Builtin (B_not a) -> Hash.combine2 20 (sub_hash a) - | Builtin (B_and (t1,t2)) -> Hash.combine3 21 (sub_hash t1) (sub_hash t2) - | Builtin (B_or (t1,t2)) -> Hash.combine3 22 (sub_hash t1) (sub_hash t2) - | Builtin (B_imply (t1,t2)) -> Hash.combine3 23 (sub_hash t1) (sub_hash t2) + | Builtin (B_and l) -> Hash.combine2 21 (Hash.list sub_hash l) + | Builtin (B_or l) -> Hash.combine2 22 (Hash.list sub_hash l) + | Builtin (B_imply (l1,t2)) -> Hash.combine3 23 (Hash.list sub_hash l1) (sub_hash t2) | Builtin (B_eq (t1,t2)) -> Hash.combine3 24 (sub_hash t1) (sub_hash t2) + | Custom {view;tc} -> tc.tc_t_hash sub_hash view (* equality that relies on physical equality of subterms *) let equal (a:A.t term_cell) b : bool = match a, b with @@ -51,18 +52,21 @@ module Make_eq(A : ARG) = struct | Builtin b1, Builtin b2 -> begin match b1, b2 with | B_not a1, B_not a2 -> sub_eq a1 a2 - | B_and (a1,b1), B_and (a2,b2) - | B_or (a1,b1), B_or (a2,b2) - | B_eq (a1,b1), B_eq (a2,b2) - | B_imply (a1,b1), B_imply (a2,b2) -> sub_eq a1 a2 && sub_eq b1 b2 + | B_and l1, B_and l2 + | B_or l1, B_or l2 -> CCEqual.list sub_eq l1 l2 + | B_eq (a1,b1), B_eq (a2,b2) -> sub_eq a1 a2 && sub_eq b1 b2 + | B_imply (a1,b1), B_imply (a2,b2) -> CCEqual.list sub_eq a1 a2 && sub_eq b1 b2 | B_not _, _ | B_and _, _ | B_eq _, _ | B_or _, _ | B_imply _, _ -> false end + | Custom r1, Custom r2 -> + r1.tc.tc_t_equal sub_eq r1.view r2.view | True, _ | App_cst _, _ | If _, _ | Case _, _ | Builtin _, _ + | Custom _, _ -> false end[@@inline] @@ -90,24 +94,26 @@ let cstor_proj cstor i t = app_cst p (IArray.singleton t) let builtin b = + let mk_ x = Builtin x in (* normalize a bit *) - let b = match b with - | B_eq (a,b) when a.term_id > b.term_id -> B_eq (b,a) - | B_and (a,b) when a.term_id > b.term_id -> B_and (b,a) - | B_or (a,b) when a.term_id > b.term_id -> B_or (b,a) - | _ -> b - in - Builtin b + begin match b with + | B_imply ([], x) -> x.term_cell + | B_eq (a,b) when a.term_id = b.term_id -> true_ + | B_eq (a,b) when a.term_id > b.term_id -> mk_ @@ B_eq (b,a) + | _ -> mk_ b + end let not_ t = match t.term_cell with | Builtin (B_not t') -> t'.term_cell | _ -> builtin (B_not t) -let and_ a b = builtin (B_and (a,b)) -let or_ a b = builtin (B_or (a,b)) +let and_ l = builtin (B_and l) +let or_ l = builtin (B_or l) let imply a b = builtin (B_imply (a,b)) let eq a b = builtin (B_eq (a,b)) +let custom ~tc view = Custom {view;tc} + (* type of an application *) let rec app_ty_ ty l : Ty.t = match Ty.view ty, l with | _, [] -> ty @@ -132,6 +138,7 @@ let ty (t:t): Ty.t = match t with let _, rhs = ID.Map.choose m in rhs.term_ty | Builtin _ -> Ty.prop + | Custom {view;tc} -> tc.tc_t_ty (fun t -> t.term_ty) view module Tbl = CCHashtbl.Make(struct type t = term term_cell diff --git a/src/smt/Term_cell.mli b/src/smt/Term_cell.mli index bd2726ef..71d3c2a8 100644 --- a/src/smt/Term_cell.mli +++ b/src/smt/Term_cell.mli @@ -15,11 +15,12 @@ val cstor_proj : data_cstor -> int -> term -> t val case : term -> term ID.Map.t -> t val if_ : term -> term -> term -> t val builtin : term builtin -> t -val and_ : term -> term -> t -val or_ : term -> term -> t +val and_ : term list -> t +val or_ : term list -> t val not_ : term -> t -val imply : term -> term -> t +val imply : term list -> term -> t val eq : term -> term -> t +val custom : tc:term_view_tc -> term term_view_custom -> t val ty : t -> Ty.t (** Compute the type of this term cell. Not totally free *)