From 162fd37d9d1336e48a5bd9c18450eb76818f9213 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sun, 18 Jul 2021 20:49:15 -0400 Subject: [PATCH] wip: refactor --- src/core/Sidekick_core.ml | 5 +- src/sat/Solver.ml | 271 +++++++++++++++++++++++--------------- src/sat/Solver_intf.ml | 5 +- 3 files changed, 168 insertions(+), 113 deletions(-) diff --git a/src/core/Sidekick_core.ml b/src/core/Sidekick_core.ml index 05086320..e4642fb4 100644 --- a/src/core/Sidekick_core.ml +++ b/src/core/Sidekick_core.ml @@ -957,10 +957,11 @@ module type SOLVER = sig val equal : t -> t -> bool val hash : t -> int - val pp : t CCFormat.printer + + val pp : solver -> t CCFormat.printer + val formula : solver -> t -> lit val neg : t -> t - val formula : t -> lit val sign : t -> bool end diff --git a/src/sat/Solver.ml b/src/sat/Solver.ml index c341673b..f816d3c7 100644 --- a/src/sat/Solver.ml +++ b/src/sat/Solver.ml @@ -13,6 +13,24 @@ module type PLUGIN_CDCL_T = Solver_intf.PLUGIN_CDCL_T let invalid_argf fmt = Format.kasprintf (fun msg -> invalid_arg ("sidekick.sat: " ^ msg)) fmt +module type INT_ID = sig + type t = private int + val equal : t -> t -> bool + val compare : t -> t -> int + val hash : t -> int + val to_int : t -> int + val of_int_unsafe : int -> t +end + +module Mk_int_id() = struct + type t = int + let equal : t -> t -> bool = (=) + let compare : t -> t -> int = compare + let hash = CCHash.int + let[@inline] to_int i = i + let[@inline] of_int_unsafe i = i +end + module Make(Plugin : PLUGIN) = struct module Formula = Plugin.Formula @@ -21,27 +39,28 @@ module Make(Plugin : PLUGIN) type theory = Plugin.t type lemma = Plugin.proof - type var = { - vid : int; - pa : atom; - na : atom; - mutable v_fields : int; - mutable v_level : int; - mutable v_idx: int; (** position in heap *) - mutable v_weight : float; (** Weight (for the heap), tracking activity *) - mutable reason : reason option; - } + (* ### types ### *) - and atom = { - aid : int; - var : var; - neg : atom; - lit : formula; - mutable is_true : bool; - watched : clause Vec.t; - } + (* a boolean variable (positive int) *) + module Var : INT_ID = Mk_int_id() + type var = Var.t - and clause = { + (* a signed atom. +v is (v << 1), -v is (v<<1 | 1) *) + module Atom : sig + include INT_ID + val of_var : var -> t + val neg : t -> t + val sign : t -> bool + end = struct + include Mk_int_id() + let[@inline] of_var v = (v:var:>int) lsl 1 + let[@inline] neg i = (i lxor 1) + let[@inline] sign i = (i land 1) = 0 + end + type atom = Atom.t + + (* TODO: special clause allocator *) + type clause = { cid: int; atoms : atom array; mutable cpremise : premise; @@ -49,11 +68,6 @@ module Make(Plugin : PLUGIN) mutable flags: int; (* bitfield *) } - and reason = - | Decision - | Bcp of clause - | Bcp_lazy of clause lazy_t - (* TODO: remove, replace with user-provided proof trackng device? for pure SAT, [reason] is sufficient *) and premise = @@ -63,98 +77,124 @@ module Make(Plugin : PLUGIN) | History of clause list | Empty_premise - (* Constructors *) - module MF = Hashtbl.Make(Formula) + and reason = + | Decision + | Bcp of clause + | Bcp_lazy of clause lazy_t - (* state for variables. declared separately because it simplifies our + (* ### stores ### *) + + module Form_tbl = Hashtbl.Make(Formula) + + (* variable store. declared separately because it simplifies our life below, as it's required to construct new atoms/variables *) - type st = { - f_map: var MF.t; - vars: var Vec.t; - mutable cpt_mk_var: int; - mutable cpt_mk_clause: int; - } - - let create_st ?(size=`Big) () : st = - let size_map = match size with - | `Tiny -> 8 - | `Small -> 16 - | `Big -> 4096 - in - { f_map = MF.create size_map; - vars = Vec.create(); - cpt_mk_var = 0; - cpt_mk_clause = 0; + module Vars = struct + type t = { + of_form: var Form_tbl.t; + level: int Vec.t; + heap_idx: int Vec.t; + weight: float Vec.t; + reason: reason option Vec.t; + seen: Bitvec.t; + default_polarity: Bitvec.t; + mutable count : int; } - let nb_elt st = Vec.size st.vars - let get_elt st i = Vec.get st.vars i - let iter_elt st f = Vec.iter f st.vars + let create ?(size=`Big) () : t = + let size_map = match size with + | `Tiny -> 8 + | `Small -> 16 + | `Big -> 4096 + in + { of_form = Form_tbl.create size_map; + level = Vec.create(); + heap_idx = Vec.create(); + weight = Vec.create(); + reason = Vec.create(); + seen = Bitvec.create(); + default_polarity = Bitvec.create(); + count = 1; + } - let kind_of_clause c = match c.cpremise with - | Hyp _ -> "H" - | Lemma _ -> "T" - | Local -> "L" - | History _ -> "C" - | Empty_premise -> "" + (* allocate new variable *) + let alloc ?default_pol:(pol=true) self (form:formula) : var = + let {count; of_form; level; heap_idx; weight; + reason; seen; default_polarity; } = self in + let v_idx = count in + let var = Var.of_int_unsafe v_idx in + self.count <- 1 + count; + Form_tbl.add of_form form var; + Vec.push level (-1); + Vec.push heap_idx (-1); + Vec.push reason None; + Vec.push weight 0.; + Bitvec.ensure_size seen v_idx; + Bitvec.ensure_size default_polarity v_idx; + Bitvec.set default_polarity v_idx pol; + var - (* some boolean flags for variables, used as masks *) - let seen_var = 0b1 - let seen_pos = 0b10 - let seen_neg = 0b100 - let default_pol_true = 0b1000 + let[@inline] level self v = Vec.get self.level (v:var:>int) + let[@inline] reason self v = Vec.get self.reason (v:var:>int) + let[@inline] weight self v = Vec.get self.weight (v:var:>int) + let[@inline] set_weight self v w = Vec.set self.weight (v:var:>int) w + let[@inline] mark self v = Bitvec.set self.seen (v:var:>int) true + let[@inline] unmark self v = Bitvec.set self.seen (v:var:>int) false + let[@inline] marked self v = Bitvec.get self.seen (v:var:>int) + let[@inline] set_default_pol self v b = Bitvec.set self.default_polarity (v:var:>int) b + let[@inline] default_pol self v = Bitvec.get self.default_polarity (v:var:>int) + let[@inline] heap_idx self v = Vec.get self.heap_idx (v:var:>int) + let[@inline] set_idx self v i = Vec.set self.heap_idx (v:var:>int) i + end - module Var = struct - let[@inline] level v = v.v_level - let[@inline] pos v = v.pa - let[@inline] neg v = v.na - let[@inline] reason v = v.reason - let[@inline] weight v = v.v_weight - let[@inline] set_weight v w = v.v_weight <- w - let[@inline] mark v = v.v_fields <- v.v_fields lor seen_var - let[@inline] unmark v = v.v_fields <- v.v_fields land (lnot seen_var) - let[@inline] marked v = (v.v_fields land seen_var) <> 0 - let[@inline] set_default_pol_true v = v.v_fields <- v.v_fields lor default_pol_true - let[@inline] set_default_pol_false v = v.v_fields <- v.v_fields land (lnot default_pol_true) - let[@inline] default_pol v = (v.v_fields land default_pol_true) <> 0 - let[@inline] idx v = v.v_idx - let[@inline] set_idx v i = v.v_idx <- i + module Atoms = struct + type t = { + positive: bool; (* is this for positive atoms *) + is_true: Bitvec.t; + seen: Bitvec.t; + form: formula Vec.t; + (* TODO: store watches in clauses instead *) + watched: clause Vec.t Vec.t; + } - let make ?(default_pol=true) (st:st) (t:formula) : var * Solver_intf.negated = - let lit, negated = Formula.norm t in - try - MF.find st.f_map lit, negated - with Not_found -> - let cpt_double = st.cpt_mk_var lsl 1 in - let rec var = - { vid = st.cpt_mk_var; - pa = pa; - na = na; - v_fields = 0; - v_level = -1; - v_idx= -1; - v_weight = 0.; - reason = None; - } - and pa = - { var = var; - lit = lit; - watched = Vec.create(); - neg = na; - is_true = false; - aid = cpt_double (* aid = vid*2 *) } - and na = - { var = var; - lit = Formula.neg lit; - watched = Vec.create(); - neg = pa; - is_true = false; - aid = cpt_double + 1 (* aid = vid*2+1 *) } in - MF.add st.f_map lit var; - st.cpt_mk_var <- st.cpt_mk_var + 1; - if default_pol then set_default_pol_true var; - Vec.push st.vars var; - var, negated + let create ~positive () : t = + { positive; + is_true=Bitvec.create(); + form=Vec.create(); + watched=Vec.create(); + seen=Bitvec.create(); + } + + let of_var (self:t) (v:Var.t) : atom = + let a = Atom.of_var v in + if self.positive then a else Atom.neg a + + let alloc (self:t) (v:Var.t) (f:formula) : unit = + let {positive; is_true; seen; watched; form; } = self in + assert (Vec.size form = (v:var:>int) - 1); + Bitvec.ensure_size is_true (v:var:>int); + Bitvec.ensure_size seen (v:var:>int); + Vec.push form f; + () + end + + (* state holding variables *) + type st = { + vars: Vars.t; + pa: Atoms.t; + na: Atoms.t; + } + + (* create new variable *) + let mk_var (self:st) ?default_pol (t:formula) : var * Solver_intf.negated = + let form, negated = Formula.norm t in + try Form_tbl.find self.vars.Vars.of_form form, negated + with Not_found -> + let v = Vars.alloc ?default_pol self.vars form in + Atoms.alloc self.pa v form; + Atoms.alloc self.na v (Formula.neg form); + v, negated + + (* FIXME (* Marking helpers *) let[@inline] clear v = @@ -163,7 +203,20 @@ module Make(Plugin : PLUGIN) let[@inline] seen_both v = (seen_pos land v.v_fields <> 0) && (seen_neg land v.v_fields <> 0) - end + *) + + (* + let nb_elt st = Vec.size st.vars + let get_elt st i = Vec.get st.vars i + let iter_elt st f = Vec.iter f st.vars + *) + + let kind_of_clause c = match c.cpremise with + | Hyp _ -> "H" + | Lemma _ -> "T" + | Local -> "L" + | History _ -> "C" + | Empty_premise -> "" module Atom = struct type t = atom diff --git a/src/sat/Solver_intf.ml b/src/sat/Solver_intf.ml index afe2dfc6..e06d3a7e 100644 --- a/src/sat/Solver_intf.ml +++ b/src/sat/Solver_intf.ml @@ -335,8 +335,9 @@ module type S = sig val neg : t -> t val sign : t -> bool val abs : t -> t - val formula : t -> formula - val pp : t printer + + val formula : solver -> t -> formula + val pp : solver -> t printer end module Clause : sig