From edeb28c8ad16b728f32f1ce0ac1d67f800dd2a60 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Fri, 25 May 2018 19:35:33 -0500 Subject: [PATCH] refactor(smt): use list of lits as explanations for propagations --- src/smt/Congruence_closure.ml | 36 +++++++++++------------------ src/smt/Congruence_closure.mli | 10 ++++---- src/smt/Explanation.ml | 1 + src/smt/Solver.ml | 6 ++--- src/smt/Solver_types.ml | 8 +++++-- src/smt/Theory.ml | 6 ++--- src/smt/Theory_combine.ml | 8 ++----- src/smt/th_bool/Sidekick_th_bool.ml | 13 ++++------- 8 files changed, 39 insertions(+), 49 deletions(-) diff --git a/src/smt/Congruence_closure.ml b/src/smt/Congruence_closure.ml index 550e4f52..ff13ce81 100644 --- a/src/smt/Congruence_closure.ml +++ b/src/smt/Congruence_closure.ml @@ -5,6 +5,7 @@ open Solver_types type node = Equiv_class.t type repr = Equiv_class.t +type conflict = Theory.conflict (** A signature is a shallow term shape where immediate subterms are representative *) @@ -25,7 +26,7 @@ type actions = { on_merge:repr -> repr -> explanation -> unit; (** Call this when two classes are merged *) - raise_conflict: 'a. Lit.Set.t -> 'a; + raise_conflict: 'a. conflict -> 'a; (** Report a conflict *) propagate: Lit.t -> Lit.t list -> unit; @@ -170,14 +171,9 @@ let push_combine cc t u e : unit = Equiv_class.pp t Equiv_class.pp u Explanation.pp e); Vec.push cc.combine (t,u,e) -let push_propagation cc (lit:lit) (expl:explanation Bag.t): unit = - Log.debugf 5 - (fun k->k "(@[cc.push_propagate@ %a@ :expl (@[%a@])@])" - Lit.pp lit (Util.pp_seq Explanation.pp) @@ Bag.to_seq expl); - cc.acts.propagate lit expl - -let[@inline] union cc (a:node) (b:node) (e:explanation): unit = +let[@inline] union cc (a:node) (b:node) (e:Lit.t list): unit = if not (same_class cc a b) then ( + let e = Explanation.E_lits e in push_combine cc a b e; (* start by merging [a=b] *) ) @@ -195,7 +191,7 @@ let rec reroot_expl (cc:t) (n:node): unit = n.n_expl <- E_none; end -let[@inline] raise_conflict (cc:t) (e:Lit.Set.t): _ = +let[@inline] raise_conflict (cc:t) (e:conflict): _ = cc.acts.raise_conflict e let[@inline] all_classes cc : repr Sequence.t = @@ -251,6 +247,7 @@ let rec decompose_explain cc (e:explanation): unit = begin match e with | E_reduction -> () | E_lit lit -> ps_add_lit cc lit + | E_lits l -> List.iter (ps_add_lit cc) l | E_custom {args;_} -> (* decompose sub-expls *) List.iter (decompose_explain cc) args @@ -350,12 +347,6 @@ let add_tag_n cc (n:node) (tag:int) (expl:explanation) : unit = n.n_tags <- Util.Int_map.add tag (n,expl) n.n_tags; ) -(* conflict because [expl => t1 ≠ t2] but they are the same *) -let distinct_conflict cc (t1 : term) (t2: term) (expl:explanation Bag.t) : 'a = - let lits = explain_unfold_bag cc expl in - let lits = explain_eq_t ~init:lits cc t1 t2 in - raise_conflict cc lits - (* main CC algo: add terms from [pending] to the signature table, check for collisions *) let rec update_pending (cc:t): unit = @@ -418,7 +409,7 @@ and update_combine cc = let lits = explain_unfold ~init:lits cc e_ab in let lits = explain_eq_n ~init:lits cc a n1 in let lits = explain_eq_n ~init:lits cc b n2 in - raise_conflict cc lits) + raise_conflict cc @@ Lit.Set.elements lits) ra.n_tags rb.n_tags in (* solve the equation, if both [ra] and [rb] are semantic. @@ -438,9 +429,8 @@ and update_combine cc = | Solve_fail {expl} -> Log.debugf 5 (fun k->k "(@[solve-fail@ (@[= %a %a@])@ :expl %a@])" - Term.pp t_a Term.pp t_b Explanation.pp expl); - let lits = explain_unfold cc expl in - raise_conflict cc lits + Term.pp t_a Term.pp t_b (CCFormat.Dump.list Lit.pp) expl); + raise_conflict cc expl end | _ -> assert false ); @@ -578,7 +568,7 @@ let assert_eq cc (t:term) (u:term) expl : unit = union cc n1 n2 expl ) -let assert_distinct cc (l:term list) ~neq expl : unit = +let assert_distinct cc (l:term list) ~neq (lit:Lit.t) : unit = assert (match l with[] | [_] -> false | _ -> true); let tag = Term.id neq in Log.debugf 5 @@ -592,12 +582,12 @@ let assert_distinct cc (l:term list) ~neq expl : unit = | Some ((t1,_n1),(t2,_n2)) -> (* two classes are already equal *) Log.debugf 5 (fun k->k "(@[cc.assert_distinct.conflict@ %a = %a@])" Term.pp t1 Term.pp t2); - let lits = explain_unfold cc expl in + let lits = Lit.Set.singleton lit in let lits = explain_eq_t ~init:lits cc t1 t2 in - raise_conflict cc lits + raise_conflict cc @@ Lit.Set.to_list lits | None -> (* put a tag on all equivalence classes, that will make their merge fail *) - List.iter (fun (_,n) -> add_tag_n cc n tag expl) l + List.iter (fun (_,n) -> add_tag_n cc n tag @@ Explanation.lit lit) l end (* handling "distinct" constraints *) diff --git a/src/smt/Congruence_closure.mli b/src/smt/Congruence_closure.mli index 3be07b6a..cac49888 100644 --- a/src/smt/Congruence_closure.mli +++ b/src/smt/Congruence_closure.mli @@ -11,6 +11,8 @@ type node = Equiv_class.t type repr = Equiv_class.t (** Node that is currently a representative *) +type conflict = Theory.conflict + type actions = { on_backtrack:(unit -> unit) -> unit; (** Register a callback to be invoked upon backtracking below the current level *) @@ -18,7 +20,7 @@ type actions = { on_merge:repr -> repr -> explanation -> unit; (** Call this when two classes are merged *) - raise_conflict: 'a. Lit.Set.t -> 'a; + raise_conflict: 'a. conflict -> 'a; (** Report a conflict *) propagate: Lit.t -> Lit.t list -> unit; @@ -40,7 +42,7 @@ val find : t -> node -> repr val same_class : t -> node -> node -> bool (** Are these two classes the same in the current CC? *) -val union : t -> node -> node -> explanation -> unit +val union : t -> node -> node -> Lit.t list -> unit (** Merge the two equivalence classes. Will be undone on backtracking. *) val mem : t -> term -> bool @@ -60,9 +62,9 @@ val assert_lit : t -> Lit.t -> unit (** Given a literal, assume it in the congruence closure and propagate its consequences. Will be backtracked. *) -val assert_eq : t -> term -> term -> explanation -> unit +val assert_eq : t -> term -> term -> Lit.t list -> unit -val assert_distinct : t -> term list -> neq:term -> explanation -> unit +val assert_distinct : t -> term list -> neq:term -> Lit.t -> unit (** [assert_distinct l ~expl:u e] asserts all elements of [l] are distinct with explanation [e] precond: [u = distinct l] *) diff --git a/src/smt/Explanation.ml b/src/smt/Explanation.ml index 9a301cf2..91a2c649 100644 --- a/src/smt/Explanation.ml +++ b/src/smt/Explanation.ml @@ -4,6 +4,7 @@ open Solver_types type t = explanation = | E_reduction | E_lit of lit + | E_lits of lit list | E_congruence of cc_node * cc_node | E_injectivity of cc_node * cc_node | E_reduce_eq of cc_node * cc_node diff --git a/src/smt/Solver.ml b/src/smt/Solver.ml index fad49a9f..6b8ffb44 100644 --- a/src/smt/Solver.ml +++ b/src/smt/Solver.ml @@ -395,10 +395,10 @@ let assume (self:t) (c:Lit.t IArray.t) : unit = Sat_solver.add_clause ~permanent:true sat c let[@inline] assume_eq self t u expl : unit = - Congruence_closure.assert_eq (cc self) t u (E_lit expl) + Congruence_closure.assert_eq (cc self) t u [expl] -let[@inline] assume_distinct self l ~neq expl : unit = - Congruence_closure.assert_distinct (cc self) l (E_lit expl) ~neq +let[@inline] assume_distinct self l ~neq lit : unit = + Congruence_closure.assert_distinct (cc self) l lit ~neq let check_model (s:t) = Sat_solver.check_model s.solver diff --git a/src/smt/Solver_types.ml b/src/smt/Solver_types.ml index 3c0459f8..29fd2298 100644 --- a/src/smt/Solver_types.ml +++ b/src/smt/Solver_types.ml @@ -71,7 +71,7 @@ and solve_result = } (** Success, the two terms being equal is equivalent to the given substitution *) | Solve_fail of { - expl: explanation; + expl: lit list; } (** Failure, because of the given explanation. The two terms cannot be equal *) @@ -105,6 +105,7 @@ and explanation_forest_link = and explanation = | E_reduction (* by pure reduction, tautologically equal *) | E_lit of lit (* because of this literal *) + | E_lits of lit list (* because of this (true) conjunction *) | E_congruence of cc_node * cc_node (* these terms are congruent *) | E_injectivity of cc_node * cc_node (* injective function *) | E_reduce_eq of cc_node * cc_node (* reduce because those are equal by reduction *) @@ -232,19 +233,21 @@ let rec cmp_exp a b = | E_reduction -> 2 | E_injectivity _ -> 3 | E_reduce_eq _ -> 5 | E_custom _ -> 6 + | E_lits _ -> 7 in begin match a, b with | E_congruence (t1,t2), E_congruence (u1,u2) -> CCOrd.(cmp_cc_node t1 u1 (cmp_cc_node, t2, u2)) | E_reduction, E_reduction -> 0 | E_lit l1, E_lit l2 -> cmp_lit l1 l2 + | E_lits l1, E_lits l2 -> CCList.compare cmp_lit l1 l2 | E_injectivity (t1,t2), E_injectivity (u1,u2) -> CCOrd.(cmp_cc_node t1 u1 (cmp_cc_node, t2, u2)) | E_reduce_eq (t1, u1), E_reduce_eq (t2,u2) -> CCOrd.(cmp_cc_node t1 t2 (cmp_cc_node, u1, u2)) | E_custom r1, E_custom r2 -> CCOrd.(ID.compare r1.name r2.name (list cmp_exp, r1.args, r2.args)) - | E_congruence _, _ | E_lit _, _ | E_reduction, _ + | E_congruence _, _ | E_lit _, _ | E_reduction, _ | E_lits _, _ | E_injectivity _, _ | E_reduce_eq _, _ | E_custom _, _ -> CCInt.compare (toint a)(toint b) end @@ -317,6 +320,7 @@ let pp_cc_node out n = pp_term out n.n_term let pp_explanation out (e:explanation) = match e with | E_reduction -> Fmt.string out "reduction" | E_lit lit -> pp_lit out lit + | E_lits l -> CCFormat.Dump.list pp_lit out l | E_congruence (a,b) -> Format.fprintf out "(@[congruence@ %a@ %a@])" pp_cc_node a pp_cc_node b | E_injectivity (a,b) -> diff --git a/src/smt/Theory.ml b/src/smt/Theory.ml index 7c86cdad..90620f58 100644 --- a/src/smt/Theory.ml +++ b/src/smt/Theory.ml @@ -31,7 +31,7 @@ type state = State : { (** Unsatisfiable conjunction. Its negation will become a conflict clause *) -type conflict = Lit.Set.t +type conflict = Lit.t list (** Actions available to a theory during its lifetime *) type actions = { @@ -41,13 +41,13 @@ type actions = { raise_conflict: 'a. conflict -> 'a; (** Give a conflict clause to the solver *) - propagate_eq: Term.t -> Term.t -> Lit.Set.t -> unit; + propagate_eq: Term.t -> Term.t -> Lit.t list -> unit; (** Propagate an equality [t = u] because [e] *) propagate_distinct: Term.t list -> neq:Term.t -> Lit.t -> unit; (** Propagate a [distinct l] because [e] (where [e = neq] *) - propagate: Lit.t -> Lit.Set.t -> unit; + propagate: Lit.t -> Lit.t list -> unit; (** Propagate a boolean using a unit clause. [expl => lit] must be a theory lemma, that is, a T-tautology *) diff --git a/src/smt/Theory_combine.ml b/src/smt/Theory_combine.ml index 2c6b8bb9..64bcc862 100644 --- a/src/smt/Theory_combine.ml +++ b/src/smt/Theory_combine.ml @@ -15,8 +15,7 @@ module Form = Lit type formula = Lit.t type proof = Proof.t - -type conflict = Lit.Set.t +type conflict = Theory.conflict (* raise upon conflict *) exception Exn_conflict of conflict @@ -63,10 +62,7 @@ let cdcl_return_res (self:t) : _ Sat_solver.res = | None -> Sat_solver.Sat | Some lit_set -> - let conflict_clause = - Lit.Set.to_list lit_set - |> IArray.of_list_map Lit.neg - in + let conflict_clause = IArray.of_list_map Lit.neg lit_set in Sat_solver.Log.debugf 3 (fun k->k "(@[<1>conflict@ clause: %a@])" Theory.Clause.pp conflict_clause); diff --git a/src/smt/th_bool/Sidekick_th_bool.ml b/src/smt/th_bool/Sidekick_th_bool.ml index 2022de44..915fdc0f 100644 --- a/src/smt/th_bool/Sidekick_th_bool.ml +++ b/src/smt/th_bool/Sidekick_th_bool.ml @@ -255,7 +255,7 @@ let tseitin (self:t) (lit:Lit.t) (lit_t:term) (b:term builtin) : unit = | B_not _ -> assert false (* normalized *) | B_eq (t,u) -> if Lit.sign lit then ( - self.acts.Theory.propagate_eq t u (Lit.Set.singleton lit) + self.acts.Theory.propagate_eq t u [lit] ) else ( self.acts.Theory.propagate_distinct [t;u] ~neq:lit_t lit ) @@ -269,11 +269,10 @@ let tseitin (self:t) (lit:Lit.t) (lit_t:term) (b:term builtin) : unit = | B_and subs -> if Lit.sign lit then ( (* propagate [lit => subs_i] *) - let expl = Lit.Set.singleton lit in List.iter (fun sub -> let sublit = Lit.atom sub in - self.acts.Theory.propagate sublit expl) + self.acts.Theory.propagate sublit [lit]) subs ) else ( (* propagate [¬lit => ∨_i ¬ subs_i] *) @@ -287,11 +286,10 @@ let tseitin (self:t) (lit:Lit.t) (lit_t:term) (b:term builtin) : unit = self.acts.Theory.add_local_axiom (IArray.of_list c) ) else ( (* propagate [¬lit => ¬subs_i] *) - let expl = Lit.Set.singleton lit in List.iter (fun sub -> let sublit = Lit.atom ~sign:false sub in - self.acts.Theory.propagate sublit expl) + self.acts.Theory.propagate sublit [lit]) subs ) | B_imply (guard,concl) -> @@ -300,14 +298,13 @@ let tseitin (self:t) (lit:Lit.t) (lit_t:term) (b:term builtin) : unit = let c = Lit.atom concl :: Lit.neg lit :: List.map (Lit.atom ~sign:false) guard in self.acts.Theory.add_local_axiom (IArray.of_list c) ) else ( - let expl = Lit.Set.singleton lit in (* propagate [¬lit => ¬concl] *) - self.acts.Theory.propagate (Lit.atom ~sign:false concl) expl; + self.acts.Theory.propagate (Lit.atom ~sign:false concl) [lit]; (* propagate [¬lit => ∧_i guard_i] *) List.iter (fun sub -> let sublit = Lit.atom ~sign:true sub in - self.acts.Theory.propagate sublit expl) + self.acts.Theory.propagate sublit [lit]) guard )