diff --git a/src/sat/Internal.ml b/src/sat/Internal.ml index fd14ddf0..0db0df99 100644 --- a/src/sat/Internal.ml +++ b/src/sat/Internal.ml @@ -728,7 +728,6 @@ module Make (Th : Theory_intf.S) = struct end let[@inline] theory st = Lazy.force st.th - let[@inline] nb_clauses st = Vec.size st.clauses let[@inline] decision_level st = Vec.size st.elt_levels let[@inline] base_level st = if decision_level st > 0 then 1 else 0 (* first level=assumptions *) diff --git a/src/smt/Congruence_closure.ml b/src/smt/Congruence_closure.ml index 4bcc4701..dab6b42a 100644 --- a/src/smt/Congruence_closure.ml +++ b/src/smt/Congruence_closure.ml @@ -72,8 +72,8 @@ let[@inline] on_backtrack cc f : unit = let[@inline] is_root_ (n:node) : bool = n.n_root == n let[@inline] size_ (r:repr) = - assert (is_root_ (r:>node)); - Bag.size (r :> node).n_parents + assert (is_root_ r); + Bag.size r.n_parents (* check if [t] is in the congruence closure. Invariant: [in_cc t => in_cc u, forall u subterm t] *) @@ -90,9 +90,9 @@ let rec find_rec cc (n:node) : repr = let old_root = n.n_root in let root = find_rec cc old_root in (* path compression *) - if (root :> node) != old_root then ( + if root != old_root then ( on_backtrack cc (fun () -> n.n_root <- old_root); - n.n_root <- (root :> node); + n.n_root <- root; ); root ) @@ -154,8 +154,11 @@ let is_done (cc:t): bool = Vec.is_empty cc.combine let push_pending cc t : unit = - Log.debugf 5 (fun k->k "(@[cc.push_pending@ %a@])" Equiv_class.pp t); - Vec.push cc.pending t + if not @@ Equiv_class.get_field Equiv_class.field_is_pending t then ( + Log.debugf 5 (fun k->k "(@[cc.push_pending@ %a@])" Equiv_class.pp t); + Equiv_class.set_field Equiv_class.field_is_pending true t; + Vec.push cc.pending t + ) let push_combine cc t u e : unit = Log.debugf 5 @@ -322,11 +325,13 @@ let rec update_pending (cc:t): unit = might have changed *) while not (Vec.is_empty cc.pending) do let n = Vec.pop_last cc.pending in + Equiv_class.set_field Equiv_class.field_is_pending false n; (* check if some parent collided *) begin match find_by_signature cc n.n_term with | None -> (* add to the signature table [sig(n) --> n] *) add_signature cc n.n_term n + | Some u when n == u -> () | Some u -> (* must combine [t] with [r] *) if not @@ same_class cc n u then ( @@ -337,8 +342,7 @@ let rec update_pending (cc:t): unit = | App_cst (f1, a1), App_cst (f2, a2) -> assert (Cst.equal f1 f2); assert (IArray.length a1 = IArray.length a2); - Explanation.mk_merges @@ - IArray.map2 (fun u1 u2 -> add_ cc u1, add_ cc u2) a1 a2 + Explanation.mk_merges @@ IArray.map2 (fun u1 u2 -> add_ cc u1, add_ cc u2) a1 a2 | If _, _ | App_cst _, _ | Bool _, _ -> assert false in @@ -386,14 +390,12 @@ and update_combine cc = in (* remove [ra.parents] from signature, put them into [st.pending] *) begin - Bag.to_seq (r_from:>node).n_parents + Bag.to_seq r_from.n_parents |> Sequence.iter (fun parent -> push_pending cc parent) end; (* perform [union ra rb] *) begin - let r_from = (r_from :> node) in - let r_into = (r_into :> node) in let r_into_old_parents = r_into.n_parents in let r_into_old_tags = r_into.n_tags in on_backtrack cc @@ -485,26 +487,26 @@ let add_seq cc seq = (* to do after backtracking: reset task lists *) let reset_tasks cc : unit = + Vec.iter (Equiv_class.set_field Equiv_class.field_is_pending false) cc.pending; Vec.clear cc.pending; Vec.clear cc.combine; () (* assert that this boolean literal holds *) -let assert_lit cc lit : unit = match Lit.view lit with - | Lit_fresh _ -> () - | Lit_atom t -> - assert (Ty.is_prop t.term_ty); - Log.debugf 5 (fun k->k "(@[cc.assert_lit@ %a@])" Lit.pp lit); - let sign = Lit.sign lit in - (* equate t and true/false *) - let rhs = if sign then true_ cc else false_ cc in - let n = add cc t in - (* TODO: ensure that this is O(1). - basically, just have [n] point to true/false and thus acquire - the corresponding value, so its superterms (like [ite]) can evaluate - properly *) - push_combine cc n rhs (E_lit lit); - () +let assert_lit cc lit : unit = + let t = Lit.view lit in + assert (Ty.is_prop t.term_ty); + Log.debugf 5 (fun k->k "(@[cc.assert_lit@ %a@])" Lit.pp lit); + let sign = Lit.sign lit in + (* equate t and true/false *) + let rhs = if sign then true_ cc else false_ cc in + let n = add cc t in + (* TODO: ensure that this is O(1). + basically, just have [n] point to true/false and thus acquire + the corresponding value, so its superterms (like [ite]) can evaluate + properly *) + push_combine cc n rhs (E_lit lit); + () let assert_eq cc (t:term) (u:term) expl : unit = let n1 = add cc t in diff --git a/src/smt/Equiv_class.ml b/src/smt/Equiv_class.ml index 1f8d3c4e..13f73a3d 100644 --- a/src/smt/Equiv_class.ml +++ b/src/smt/Equiv_class.ml @@ -5,6 +5,7 @@ type t = cc_node type payload = cc_node_payload = .. let field_is_active = Node_bits.mk_field() +let field_is_pending = Node_bits.mk_field() let () = Node_bits.freeze() let[@inline] equal (n1:t) n2 = n1==n2 @@ -59,6 +60,9 @@ let payload_pred ~f:p n = | l -> List.exists p l end +let[@inline] get_field f t = Node_bits.get f t.n_bits +let[@inline] set_field f b t = t.n_bits <- Node_bits.set f b t.n_bits + module Tbl = CCHashtbl.Make(struct type t = cc_node let equal = equal diff --git a/src/smt/Equiv_class.mli b/src/smt/Equiv_class.mli index 3f46b88b..9885094f 100644 --- a/src/smt/Equiv_class.mli +++ b/src/smt/Equiv_class.mli @@ -27,6 +27,9 @@ val field_is_active : Node_bits.field (** The term is needed for evaluation. We must try to evaluate it or to find a value for it using the theory *) +val field_is_pending : Node_bits.field +(** true iff the node is in the [cc.pending] queue *) + (** {2 basics} *) val term : t -> term @@ -49,6 +52,9 @@ val set_payload : ?can_erase:(payload -> bool) -> t -> payload -> unit @param can_erase if provided, checks whether an existing value is to be replaced instead of adding a new entry *) +val get_field : Node_bits.field -> t -> bool +val set_field : Node_bits.field -> bool -> t -> unit + module Tbl : CCHashtbl.S with type key = t (**/**) diff --git a/src/smt/Lit.ml b/src/smt/Lit.ml index 11719c27..c2cae168 100644 --- a/src/smt/Lit.ml +++ b/src/smt/Lit.ml @@ -2,48 +2,31 @@ open Solver_types type t = lit = { - lit_view : lit_view; + lit_term: term; lit_sign : bool } -and view = lit_view = - | Lit_fresh of ID.t - | Lit_atom of term - let neg l = {l with lit_sign=not l.lit_sign} -let sign t = t.lit_sign -let view (t:t): lit_view = t.lit_view +let[@inline] sign t = t.lit_sign +let[@inline] view (t:t): term = t.lit_term -let abs t: t = {t with lit_sign=true} +let[@inline] abs t: t = {t with lit_sign=true} -let make ~sign v = {lit_sign=sign; lit_view=v} +let make ~sign t = {lit_sign=sign; lit_term=t} -(* assume the ID is fresh *) -let fresh_with id = make ~sign:true (Lit_fresh id) - -(* fresh boolean literal *) -let fresh: unit -> t = - let n = ref 0 in - fun () -> - let id = ID.makef "#fresh_%d" !n in - incr n; - make ~sign:true (Lit_fresh id) - -let dummy = fresh() +let dummy = make ~sign:true Term.dummy let atom ?(sign=true) (t:term) : t = let t, sign' = Term.abs t in let sign = if not sign' then not sign else sign in - make ~sign (Lit_atom t) + make ~sign t -let as_atom (lit:t) : (term * bool) option = match lit.lit_view with - | Lit_atom t -> Some (t, lit.lit_sign) - | _ -> None +let[@inline] as_atom (lit:t) = lit.lit_term, lit.lit_sign let hash = hash_lit let compare = cmp_lit -let equal a b = compare a b = 0 +let[@inline] equal a b = compare a b = 0 let pp = pp_lit let print = pp diff --git a/src/smt/Lit.mli b/src/smt/Lit.mli index 6aac0a76..055fa373 100644 --- a/src/smt/Lit.mli +++ b/src/smt/Lit.mli @@ -3,21 +3,15 @@ open Solver_types type t = lit = { - lit_view : lit_view; + lit_term: term; lit_sign : bool } -and view = lit_view = - | Lit_fresh of ID.t - | Lit_atom of term - val neg : t -> t val abs : t -> t val sign : t -> bool -val view : t -> lit_view -val as_atom : t -> (term * bool) option -val fresh_with : ID.t -> t -val fresh : unit -> t +val view : t -> term +val as_atom : t -> term * bool val dummy : t val atom : ?sign:bool -> term -> t val hash : t -> int diff --git a/src/smt/Solver_types.ml b/src/smt/Solver_types.ml index 0f7ce5f1..ec311bdb 100644 --- a/src/smt/Solver_types.ml +++ b/src/smt/Solver_types.ml @@ -55,14 +55,10 @@ and explanation = (* boolean literal *) and lit = { - lit_view: lit_view; + lit_term: term; lit_sign: bool; } -and lit_view = - | Lit_fresh of ID.t (* fresh literals *) - | Lit_atom of term - and cst = { cst_id: ID.t; cst_view: cst_view; @@ -143,25 +139,13 @@ let[@inline] term_cmp_ a b = CCInt.compare a.term_id b.term_id let cmp_lit a b = let c = CCBool.compare a.lit_sign b.lit_sign in if c<>0 then c - else ( - let int_of_cell_ = function - | Lit_fresh _ -> 0 - | Lit_atom _ -> 1 - in - match a.lit_view, b.lit_view with - | Lit_fresh i1, Lit_fresh i2 -> ID.compare i1 i2 - | Lit_atom t1, Lit_atom t2 -> term_cmp_ t1 t2 - | Lit_fresh _, _ | Lit_atom _, _ - -> CCInt.compare (int_of_cell_ a.lit_view) (int_of_cell_ b.lit_view) - ) + else term_cmp_ a.lit_term b.lit_term let cst_compare a b = ID.compare a.cst_id b.cst_id let hash_lit a = let sign = a.lit_sign in - match a.lit_view with - | Lit_fresh i -> Hash.combine3 1 (Hash.bool sign) (ID.hash i) - | Lit_atom t -> Hash.combine3 2 (Hash.bool sign) (term_hash_ t) + Hash.combine3 2 (Hash.bool sign) (term_hash_ a.lit_term) let cmp_cc_node a b = term_cmp_ a.n_term b.n_term @@ -236,12 +220,8 @@ let pp_term = pp_term_top ~ids:false let pp_term_view = pp_term_view ~pp_id:ID.pp_name ~pp_t:pp_term let pp_lit out l = - let pp_lit_view out = function - | Lit_fresh i -> Format.fprintf out "#%a" ID.pp i - | Lit_atom t -> pp_term out t - in - if l.lit_sign then pp_lit_view out l.lit_view - else Format.fprintf out "(@[@<1>¬@ %a@])" pp_lit_view l.lit_view + if l.lit_sign then pp_term out l.lit_term + else Format.fprintf out "(@[@<1>¬@ %a@])" pp_term l.lit_term let pp_cc_node out n = pp_term out n.n_term diff --git a/src/smt/Theory_combine.ml b/src/smt/Theory_combine.ml index 83076f71..54b18f5d 100644 --- a/src/smt/Theory_combine.ml +++ b/src/smt/Theory_combine.ml @@ -48,10 +48,9 @@ let assume_lit (self:t) (lit:Lit.t) : unit = (fun k->k "(@[<1>@{th_combine.assume_lit@}@ @[%a@]@])" Lit.pp lit); (* check consistency first *) begin match Lit.view lit with - | Lit_fresh _ -> () - | Lit_atom {term_view=Bool true; _} -> () - | Lit_atom {term_view=Bool false; _} -> () - | Lit_atom _ -> + | {term_view=Bool true; _} -> () + | {term_view=Bool false; _} -> () + | _ -> (* transmit to theories. *) C_clos.assert_lit (cc self) lit; theories self (fun (module Th) -> Th.on_assert Th.state lit); @@ -98,11 +97,9 @@ let assume_real (self:t) (slice:Lit.t Sat_solver.slice_actions) = ) let add_formula (self:t) (lit:Lit.t) = - match Lit.view lit with - | Lit_atom t -> - let lazy cc = self.cc in - ignore (C_clos.add cc t : cc_node) - | Lit_fresh _ -> () + let t = Lit.view lit in + let lazy cc = self.cc in + ignore (C_clos.add cc t : cc_node) (* propagation from the bool solver *) let assume (self:t) (slice:_ Sat_solver.slice_actions) = diff --git a/src/smt/th_bool/Sidekick_th_bool.ml b/src/smt/th_bool/Sidekick_th_bool.ml index b7d255b7..268baf0f 100644 --- a/src/smt/th_bool/Sidekick_th_bool.ml +++ b/src/smt/th_bool/Sidekick_th_bool.ml @@ -236,13 +236,11 @@ let tseitin (self:t) (lit:Lit.t) (lit_t:term) (v:term view) : unit = ) let on_assert (self:t) (lit:Lit.t) = - match Lit.view lit with - | Lit.Lit_atom t -> - begin match view t with - | B_atom _ -> () - | v -> tseitin self lit t v - end - | _ -> () + let t = Lit.view lit in + begin match view t with + | B_atom _ -> () + | v -> tseitin self lit t v + end let final_check _ _ : unit = () diff --git a/src/smtlib/Process.ml b/src/smtlib/Process.ml index 170c7904..384f7442 100644 --- a/src/smtlib/Process.ml +++ b/src/smtlib/Process.ml @@ -228,10 +228,8 @@ let check_smt_model (solver:Solver.Sat_solver.t) (hyps:_ Vec.t) (m:Model.t) : un let is_true = S.Atom.is_true a in let is_false = S.Atom.is_true (S.Atom.neg a) in let sat_value = if is_true then Some true else if is_false then Some false else None in - begin match Lit.as_atom lit with - | None -> assert false - | Some (t, sign) -> - match Model.eval m t with + let t, sign = Lit.as_atom lit in + begin match Model.eval m t with | Some (V_bool b) -> let b = if sign then b else not b in if (is_true || is_false) && ((b && is_false) || (not b && is_true)) then (