diff --git a/src/cc/Congruence_closure.ml b/src/cc/Congruence_closure.ml index 26c7a3b2..ffadcdb6 100644 --- a/src/cc/Congruence_closure.ml +++ b/src/cc/Congruence_closure.ml @@ -65,7 +65,9 @@ module Make(A: ARG) = struct (* atomic explanation in the congruence closure *) and explanation = | E_reduction (* by pure reduction, tautologically equal *) + | E_merge of node * node | E_merges of (node * node) list (* caused by these merges *) + | E_congruence of node * node (* caused by normal congruence *) | E_lit of lit (* because of this literal *) | E_lits of lit list (* because of this (true) conjunction *) (* TODO: congruence case (cheaper than "merges") *) @@ -159,17 +161,27 @@ module Make(A: ARG) = struct let pp out (e:explanation) = match e with | E_reduction -> Fmt.string out "reduction" | E_lit lit -> A.Lit.pp out lit + | E_congruence (n1,n2) -> Fmt.fprintf out "(@[congruence@ %a@ %a@])" N.pp n1 N.pp n2 | E_lits l -> CCFormat.Dump.list A.Lit.pp out l + | E_merge (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" N.pp a N.pp b | E_merges l -> Format.fprintf out "(@[merges@ %a@])" Fmt.(seq ~sep:(return "@ ") @@ within "[" "]" @@ hvbox @@ pair ~sep:(return " ~@ ") N.pp N.pp) (Sequence.of_list l) - let[@inline] mk_merges l : t = E_merges l - let[@inline] mk_lit l : t = E_lit l - let[@inline] mk_lits = function [x] -> mk_lit x | l -> E_lits l let mk_reduction : t = E_reduction + let[@inline] mk_congruence n1 n2 : t = E_congruence (n1,n2) + let[@inline] mk_merge a b : t = E_merge (a,b) + let[@inline] mk_merges = function + | [] -> mk_reduction + | [(a,b)] -> mk_merge a b + | l -> E_merges l + let[@inline] mk_lit l : t = E_lit l + let[@inline] mk_lits = function + | [] -> mk_reduction + | [x] -> mk_lit x + | l -> E_lits l end (** A signature is a shallow term shape where immediate subterms @@ -407,12 +419,32 @@ module Make(A: ARG) = struct Vec.clear cc.ps_queue; () - let decompose_explain cc (e:explanation): unit = + (* TODO: turn this into a fold? *) + (* decompose explanation [e] of why [n1 = n2] *) + let decompose_explain cc (e:explanation) : unit = Log.debugf 5 (fun k->k "(@[cc.decompose_expl@ %a@])" Expl.pp e); begin match e with | E_reduction -> () + | E_congruence (n1, n2) -> + begin match n1.n_sig0, n2.n_sig0 with + | Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) -> + assert (Fun.equal f1 f2); + assert (List.length a1 = List.length a2); + List.iter2 (ps_add_obligation cc) a1 a2; + | Some (App_ho (f1, a1)), Some (App_ho (f2, a2)) -> + assert (List.length a1 = List.length a2); + ps_add_obligation cc f1 f2; + List.iter2 (ps_add_obligation cc) a1 a2; + | Some (If (a1,b1,c1)), Some (If (a2,b2,c2)) -> + ps_add_obligation cc a1 a2; + ps_add_obligation cc b1 b2; + ps_add_obligation cc c1 c2; + | _ -> + assert false + end | E_lit lit -> ps_add_lit cc lit | E_lits l -> List.iter (ps_add_lit cc) l + | E_merge (a,b) -> ps_add_obligation cc a b | E_merges l -> (* need to explain each merge in [l] *) List.iter (fun (t,u) -> ps_add_obligation cc t u) l @@ -420,15 +452,17 @@ module Make(A: ARG) = struct (* explain why [a = parent_a], where [a -> ... -> parent_a] in the proof forest *) - let rec explain_along_path ps (a:node) (parent_a:node) : unit = - if a!=parent_a then ( - match a.n_expl with + let explain_along_path ps (a:node) (parent_a:node) : unit = + let rec aux n = + if n != parent_a then ( + match n.n_expl with | FL_none -> assert false - | FL_some {next=next_a; expl=e_a_b} -> - decompose_explain ps e_a_b; - (* now prove [next_a = parent_a] *) - explain_along_path ps next_a parent_a - ) + | FL_some {next=next_n; expl=expl} -> + decompose_explain ps expl; + (* now prove [next_n = parent_a] *) + aux next_n + ) + in aux a (* find explanation *) let explain_loop (cc : t) : lit list = @@ -569,7 +603,7 @@ module Make(A: ARG) = struct | Some (Eq (a,b)) -> (* if [a=b] is now true, merge [(a=b)] and [true] *) if same_class a b then ( - let expl = Expl.mk_merges [(a,b)] in + let expl = Expl.mk_merge a b in push_combine cc n (true_ cc) expl ) | Some s0 -> @@ -584,28 +618,11 @@ module Make(A: ARG) = struct (* [t1] and [t2] must be applications of the same symbol to arguments that are pairwise equal *) assert (n != u); - let expl = - match n.n_sig0, u.n_sig0 with - | Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) -> - assert (Fun.equal f1 f2); - assert (List.length a1 = List.length a2); - (* TODO: just use "congruence" as explanation *) - Expl.mk_merges @@ List.combine a1 a2 - | Some (App_ho (f1, a1)), Some (App_ho (f2, a2)) -> - assert (List.length a1 = List.length a2); - (* TODO: just use "congruence" as explanation *) - Expl.mk_merges @@ (f1,f2)::List.combine a1 a2 - | Some (If (a1,b1,c1)), Some (If (a2,b2,c2)) -> - Expl.mk_merges @@ [a1,a2; b1,b2; c1,c2] - | _ - -> assert false - in + let expl = Expl.mk_congruence n u in push_combine cc n u expl - (* FIXME: when to actually evaluate? - eval_pending cc; - *) end + (* TODO: remove, once we have moved distinct to a theory *) and[@inline] task_combine_ cc acts = function | CT_merge (a,b,e_ab) -> task_merge_ cc acts a b e_ab | CT_distinct (l,tag,e) -> task_distinct_ cc acts l tag e diff --git a/src/cc/Congruence_closure_intf.ml b/src/cc/Congruence_closure_intf.ml index 9f33287e..709d5f71 100644 --- a/src/cc/Congruence_closure_intf.ml +++ b/src/cc/Congruence_closure_intf.ml @@ -44,6 +44,7 @@ module type S = sig val pp : t Fmt.printer val is_root : t -> bool + (** Is the node a root (ie the representative of its class)? *) val iter_class : t -> t Sequence.t (** Traverse the congruence class. @@ -68,6 +69,8 @@ module type S = sig module Expl : sig type t val pp : t Fmt.printer + + (* TODO: expose constructors for micro theories to use *) end type node = N.t @@ -80,6 +83,9 @@ module type S = sig type conflict = lit list + (* TODO: notion of micro theory, parametrized by [on_backtrack, find, etc] + and with callbacks for on_merge? *) + (* TODO micro theories as parameters *) val create : ?on_merge:(repr -> repr -> explanation -> unit) ->