refactor(cc): smaller explanations for congruence-based merges

This commit is contained in:
Simon Cruanes 2019-02-16 17:43:04 -06:00
parent 14255f94a7
commit 4d2bddc660
2 changed files with 55 additions and 32 deletions

View file

@ -65,7 +65,9 @@ module Make(A: ARG) = struct
(* atomic explanation in the congruence closure *)
and explanation =
| E_reduction (* by pure reduction, tautologically equal *)
| E_merge of node * node
| E_merges of (node * node) list (* caused by these merges *)
| E_congruence of node * node (* caused by normal congruence *)
| E_lit of lit (* because of this literal *)
| E_lits of lit list (* because of this (true) conjunction *)
(* TODO: congruence case (cheaper than "merges") *)
@ -159,17 +161,27 @@ module Make(A: ARG) = struct
let pp out (e:explanation) = match e with
| E_reduction -> Fmt.string out "reduction"
| E_lit lit -> A.Lit.pp out lit
| E_congruence (n1,n2) -> Fmt.fprintf out "(@[congruence@ %a@ %a@])" N.pp n1 N.pp n2
| E_lits l -> CCFormat.Dump.list A.Lit.pp out l
| E_merge (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" N.pp a N.pp b
| E_merges l ->
Format.fprintf out "(@[<hv1>merges@ %a@])"
Fmt.(seq ~sep:(return "@ ") @@ within "[" "]" @@ hvbox @@
pair ~sep:(return " ~@ ") N.pp N.pp)
(Sequence.of_list l)
let[@inline] mk_merges l : t = E_merges l
let[@inline] mk_lit l : t = E_lit l
let[@inline] mk_lits = function [x] -> mk_lit x | l -> E_lits l
let mk_reduction : t = E_reduction
let[@inline] mk_congruence n1 n2 : t = E_congruence (n1,n2)
let[@inline] mk_merge a b : t = E_merge (a,b)
let[@inline] mk_merges = function
| [] -> mk_reduction
| [(a,b)] -> mk_merge a b
| l -> E_merges l
let[@inline] mk_lit l : t = E_lit l
let[@inline] mk_lits = function
| [] -> mk_reduction
| [x] -> mk_lit x
| l -> E_lits l
end
(** A signature is a shallow term shape where immediate subterms
@ -407,12 +419,32 @@ module Make(A: ARG) = struct
Vec.clear cc.ps_queue;
()
let decompose_explain cc (e:explanation): unit =
(* TODO: turn this into a fold? *)
(* decompose explanation [e] of why [n1 = n2] *)
let decompose_explain cc (e:explanation) : unit =
Log.debugf 5 (fun k->k "(@[cc.decompose_expl@ %a@])" Expl.pp e);
begin match e with
| E_reduction -> ()
| E_congruence (n1, n2) ->
begin match n1.n_sig0, n2.n_sig0 with
| Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) ->
assert (Fun.equal f1 f2);
assert (List.length a1 = List.length a2);
List.iter2 (ps_add_obligation cc) a1 a2;
| Some (App_ho (f1, a1)), Some (App_ho (f2, a2)) ->
assert (List.length a1 = List.length a2);
ps_add_obligation cc f1 f2;
List.iter2 (ps_add_obligation cc) a1 a2;
| Some (If (a1,b1,c1)), Some (If (a2,b2,c2)) ->
ps_add_obligation cc a1 a2;
ps_add_obligation cc b1 b2;
ps_add_obligation cc c1 c2;
| _ ->
assert false
end
| E_lit lit -> ps_add_lit cc lit
| E_lits l -> List.iter (ps_add_lit cc) l
| E_merge (a,b) -> ps_add_obligation cc a b
| E_merges l ->
(* need to explain each merge in [l] *)
List.iter (fun (t,u) -> ps_add_obligation cc t u) l
@ -420,15 +452,17 @@ module Make(A: ARG) = struct
(* explain why [a = parent_a], where [a -> ... -> parent_a] in the
proof forest *)
let rec explain_along_path ps (a:node) (parent_a:node) : unit =
if a!=parent_a then (
match a.n_expl with
let explain_along_path ps (a:node) (parent_a:node) : unit =
let rec aux n =
if n != parent_a then (
match n.n_expl with
| FL_none -> assert false
| FL_some {next=next_a; expl=e_a_b} ->
decompose_explain ps e_a_b;
(* now prove [next_a = parent_a] *)
explain_along_path ps next_a parent_a
)
| FL_some {next=next_n; expl=expl} ->
decompose_explain ps expl;
(* now prove [next_n = parent_a] *)
aux next_n
)
in aux a
(* find explanation *)
let explain_loop (cc : t) : lit list =
@ -569,7 +603,7 @@ module Make(A: ARG) = struct
| Some (Eq (a,b)) ->
(* if [a=b] is now true, merge [(a=b)] and [true] *)
if same_class a b then (
let expl = Expl.mk_merges [(a,b)] in
let expl = Expl.mk_merge a b in
push_combine cc n (true_ cc) expl
)
| Some s0 ->
@ -584,28 +618,11 @@ module Make(A: ARG) = struct
(* [t1] and [t2] must be applications of the same symbol to
arguments that are pairwise equal *)
assert (n != u);
let expl =
match n.n_sig0, u.n_sig0 with
| Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) ->
assert (Fun.equal f1 f2);
assert (List.length a1 = List.length a2);
(* TODO: just use "congruence" as explanation *)
Expl.mk_merges @@ List.combine a1 a2
| Some (App_ho (f1, a1)), Some (App_ho (f2, a2)) ->
assert (List.length a1 = List.length a2);
(* TODO: just use "congruence" as explanation *)
Expl.mk_merges @@ (f1,f2)::List.combine a1 a2
| Some (If (a1,b1,c1)), Some (If (a2,b2,c2)) ->
Expl.mk_merges @@ [a1,a2; b1,b2; c1,c2]
| _
-> assert false
in
let expl = Expl.mk_congruence n u in
push_combine cc n u expl
(* FIXME: when to actually evaluate?
eval_pending cc;
*)
end
(* TODO: remove, once we have moved distinct to a theory *)
and[@inline] task_combine_ cc acts = function
| CT_merge (a,b,e_ab) -> task_merge_ cc acts a b e_ab
| CT_distinct (l,tag,e) -> task_distinct_ cc acts l tag e

View file

@ -44,6 +44,7 @@ module type S = sig
val pp : t Fmt.printer
val is_root : t -> bool
(** Is the node a root (ie the representative of its class)? *)
val iter_class : t -> t Sequence.t
(** Traverse the congruence class.
@ -68,6 +69,8 @@ module type S = sig
module Expl : sig
type t
val pp : t Fmt.printer
(* TODO: expose constructors for micro theories to use *)
end
type node = N.t
@ -80,6 +83,9 @@ module type S = sig
type conflict = lit list
(* TODO: notion of micro theory, parametrized by [on_backtrack, find, etc]
and with callbacks for on_merge? *)
(* TODO micro theories as parameters *)
val create :
?on_merge:(repr -> repr -> explanation -> unit) ->