diff --git a/src/cdsat/TVar.ml b/src/cdsat/TVar.ml index a67c08d2..2629f108 100644 --- a/src/cdsat/TVar.ml +++ b/src/cdsat/TVar.ml @@ -25,9 +25,8 @@ type store = { value: Value.t option Vec.t; reason: reason Vec.t; theory_views: theory_view Vec.t; - watches: t Vec.t Vec.t; has_value: Bitvec.t; - new_vars: Vec_of.t; + new_vars: Vec_of.t; (* TODO: a recycle vec to reuse identifiers *) } (* create a new variable *) @@ -39,7 +38,6 @@ let new_var_ (self : store) ~term:(term_for_v : Term.t) ~theory_view () : t = term; level; value; - watches; reason; theory_views; has_value; @@ -52,7 +50,6 @@ let new_var_ (self : store) ~term:(term_for_v : Term.t) ~theory_view () : t = Vec.push value None; (* fake *) Vec.push reason dummy_reason_; - Vec.push watches (Vec.create ()); Vec.push theory_views theory_view; Bitvec.ensure_size has_value (v + 1); Bitvec.set has_value v false; @@ -65,6 +62,9 @@ let[@inline] get_of_term (self : store) (t : Term.t) : t option = let[@inline] has_value (self : store) (v : t) : bool = Bitvec.get self.has_value v +let[@inline] equal (a : t) (b : t) = a = b +let[@inline] compare (a : t) (b : t) = compare a b +let[@inline] hash (a : t) = CCHash.int a let[@inline] level (self : store) (v : t) : int = Veci.get self.level v let[@inline] value (self : store) (v : t) : _ option = Vec.get self.value v let[@inline] theory_view (self : store) (v : t) = Vec.get self.theory_views v @@ -77,10 +77,6 @@ let[@inline] bool_value (self : store) (v : t) : _ option = let[@inline] term (self : store) (v : t) : Term.t = Vec.get self.term v let[@inline] reason (self : store) (v : t) : reason = Vec.get self.reason v -let[@inline] watchers (self : store) (v : t) : t Vec.t = Vec.get self.watches v - -let[@inline] add_watcher (self : store) (v : t) ~watcher : unit = - Vec.push (watchers self v) watcher let assign (self : store) (v : t) ~value ~level ~reason : unit = Log.debugf 50 (fun k -> @@ -115,7 +111,6 @@ module Store = struct reason = Vec.create (); term = Vec.create (); level = Veci.create (); - watches = Vec.create (); value = Vec.create (); theory_views = Vec.create (); has_value = Bitvec.create (); @@ -133,6 +128,28 @@ module Tbl = Util.Int_tbl module Set = Util.Int_set module Map = Util.Int_map +module Dense_map (Elt : sig + type t + + val create : unit -> t +end) = +struct + type elt = Elt.t + type t = { v: elt Vec.t } [@@unboxed] + + let create () : t = { v = Vec.create () } + + let[@inline] get self v = + Vec.ensure_size_with self.v Elt.create (v + 1); + Vec.get self.v v + + let[@inline] set self v x = + Vec.ensure_size_with self.v Elt.create (v + 1); + Vec.set self.v v x + + let[@inline] iter self ~f = Vec.iteri self.v ~f +end + module Internal = struct let create (self : store) (t : Term.t) ~theory_view : t = assert (not @@ Term.Weak_map.mem self.of_term t); diff --git a/src/cdsat/TVar.mli b/src/cdsat/TVar.mli index b16235b3..9e46e27b 100644 --- a/src/cdsat/TVar.mli +++ b/src/cdsat/TVar.mli @@ -23,6 +23,8 @@ end module Vec_of : Vec_sig.S with type elt := t (** Vector of variables *) +include Sidekick_sigs.EQ_ORD_HASH with type t := t + type store = Store.t type reason = @@ -53,13 +55,6 @@ val theory_view : store -> t -> theory_view val assign : store -> t -> value:Value.t -> level:int -> reason:reason -> unit val unassign : store -> t -> unit -val watchers : store -> t -> t Vec.t -(** [watchers store t] is a vector of other variables watching [t], - generally updated via {!Watch1} and {!Watch2}. - These other variables are notified when [t] is assigned. *) - -val add_watcher : store -> t -> watcher:t -> unit - val pop_new_var : store -> t option (** Pop a new variable if any, or return [None] *) @@ -69,6 +64,23 @@ module Tbl : CCHashtbl.S with type key = t module Set : CCSet.S with type elt = t module Map : CCMap.S with type key = t +(** A map optimized for dense storage. + + This is useful when most variables have an entry in the map. *) +module Dense_map (Elt : sig + type t + + val create : unit -> t +end) : sig + type elt = Elt.t + type t + + val create : unit -> t + val get : t -> var -> elt + val set : t -> var -> elt -> unit + val iter : t -> f:(var -> elt -> unit) -> unit +end + (**/**) module Internal : sig diff --git a/src/cdsat/core.ml b/src/cdsat/core.ml index ae32acbd..ee85f22d 100644 --- a/src/cdsat/core.ml +++ b/src/cdsat/core.ml @@ -19,6 +19,19 @@ type pending_assignment = { reason: Reason.t; } +type plugin_id = int +(** Each plugin gets a unique identifier *) + +type plugin_event = .. + +type watch_request = + | Watch2 of TVar.t array * plugin_event + | Watch1 of TVar.t array * plugin_event + +module Watches = Watch_schemes.Make (struct + type t = plugin_id * plugin_event +end) + type t = { tst: Term.store; vst: TVar.store; @@ -29,6 +42,7 @@ type t = { term_to_var: Term_to_var.t; vars_to_decide: Vars_to_decide.t; pending_assignments: pending_assignment Vec.t; + watches: Watches.t; mutable last_res: Check_res.t option; proof_tracer: Proof.Tracer.t; n_conflicts: int Stat.counter; @@ -36,7 +50,7 @@ type t = { n_restarts: int Stat.counter; } -and plugin_action = t +and plugin_action = t * plugin_id (* FIXME: - add [on_add_var: TVar.t -> unit] @@ -51,12 +65,15 @@ and plugin_action = t and plugin = | P : { st: 'st; + id: plugin_id; name: string; push_level: 'st -> unit; pop_levels: 'st -> int -> unit; decide: 'st -> TVar.t -> Value.t option; - propagate: 'st -> plugin_action -> TVar.t -> Value.t -> unit; + on_assign: 'st -> plugin_action -> TVar.t -> Value.t -> unit; + on_event: 'st -> plugin_action -> unit:bool -> plugin_event -> unit; term_to_var_hooks: 'st -> Term_to_var.hook list; + on_add_var: 'st -> TVar.t -> watch_request list; } -> plugin @@ -71,6 +88,7 @@ let create ?(stats = Stat.create ()) ~arg tst vst ~proof_tracer () : t = pending_assignments = Vec.create (); term_to_var = Term_to_var.create vst; vars_to_decide = Vars_to_decide.create (); + watches = Watches.create vst; last_res = None; proof_tracer; n_restarts = Stat.mk_int stats "cdsat.restarts"; @@ -80,6 +98,7 @@ let create ?(stats = Stat.create ()) ~arg tst vst ~proof_tracer () : t = let[@inline] trail self = self.trail let[@inline] iter_plugins self ~f = Vec.iter ~f self.plugins +let[@inline] get_plugin (self : t) (id : plugin_id) = Vec.get self.plugins id let[@inline] tst self = self.tst let[@inline] vst self = self.vst let[@inline] last_res self = self.last_res @@ -88,15 +107,35 @@ let[@inline] last_res self = self.last_res module Plugin = struct type t = plugin - type builder = TVar.store -> t + type builder = id:plugin_id -> TVar.store -> t let[@inline] name (P p) = p.name - let make_builder ~name ~create ~push_level ~pop_levels ~decide ~propagate - ~term_to_var_hooks () : builder = - fun vst -> + type nonrec event = plugin_event = .. + + type nonrec watch_request = watch_request = + | Watch2 of TVar.t array * event + | Watch1 of TVar.t array * event + + let make_builder ~name ~create ~push_level ~pop_levels + ?(decide = fun _ _ -> None) ?(on_assign = fun _ _ _ _ -> ()) + ?(on_event = fun _ _ ~unit:_ _ -> ()) ?(on_add_var = fun _ _ -> []) + ?(term_to_var_hooks = fun _ -> []) () : builder = + fun ~id vst -> let st = create vst in - P { name; st; push_level; pop_levels; decide; propagate; term_to_var_hooks } + P + { + name; + id; + st; + push_level; + pop_levels; + decide; + on_assign; + on_event; + term_to_var_hooks; + on_add_var; + } end (* backtracking *) @@ -118,6 +157,7 @@ let pop_levels (self : t) n : unit = trail; plugins; term_to_var = _; + watches = _; vars_to_decide = _; pending_assignments; last_res = _; @@ -149,7 +189,8 @@ let add_term_to_var_hook self h = Term_to_var.add_hook self.term_to_var h (* plugins *) let add_plugin self (pb : Plugin.builder) : unit = - let (P p as plugin) = pb self.vst in + let id = Vec.size self.plugins in + let (P p as plugin) = pb ~id self.vst in Vec.push self.plugins plugin; List.iter (add_term_to_var_hook self) (p.term_to_var_hooks p.st) @@ -157,8 +198,8 @@ let add_plugin self (pb : Plugin.builder) : unit = let add_ty (_self : t) ~ty:_ : unit = () -(* Assign [v <- value] for [reason] at [level]. - This assignment is delayed. *) +(** Assign [v <- value] for [reason] at [level]. + This assignment is delayed. *) let assign (self : t) (v : TVar.t) ~(value : Value.t) ~level:v_level ~reason : unit = Log.debugf 50 (fun k -> @@ -173,27 +214,37 @@ let raise_conflict (c : Conflict.t) : 'a = raise (E_conflict c) (* add pending assignments to the trail. This might trigger a conflict in case an assignment contradicts an already existing assignment. *) -let perform_pending_assignments (self : t) : unit = - while not (Vec.is_empty self.pending_assignments) do - let { var = v; level = v_level; value; reason } = - Vec.pop_exn self.pending_assignments - in - match TVar.value self.vst v with - | None -> - TVar.assign self.vst v ~value ~level:v_level ~reason; - Trail.push_assignment self.trail v - | Some value' when Value.equal value value' -> () (* idempotent *) - | Some _value' -> - (* conflict should only occur on booleans since they're the only - propagation-able variables *) - assert (Term.has_ty_bool (TVar.term self.vst v)); - Log.debugf 0 (fun k -> - k "TODO: conflict (incompatible values for %a)" (TVar.pp self.vst) v); - raise_conflict - @@ Conflict.make self.vst ~lit:(TLit.make true v) ~propagate_reason:reason - () +let perform_pending_assignments_real_ (self : t) : unit = + while + match Vec.pop self.pending_assignments with + | None -> false + | Some { var = v; level = v_level; value; reason } -> + (match TVar.value self.vst v with + | None -> + (* assign [v], put it on the trail. Do not notify watchers yet. *) + TVar.assign self.vst v ~value ~level:v_level ~reason; + Trail.push_assignment self.trail v + | Some value' when Value.equal value value' -> () (* idempotent *) + | Some _value' -> + (* conflict should only occur on booleans since they're the only + propagation-able variables *) + assert (Term.has_ty_bool (TVar.term self.vst v)); + Log.debugf 0 (fun k -> + k "TODO: conflict (incompatible values for %a)" (TVar.pp self.vst) v); + raise_conflict + @@ Conflict.make self.vst ~lit:(TLit.make true v) + ~propagate_reason:reason ()); + true + do + () done +let[@inline] perform_pending_assignments (self : t) : unit = + if not (Vec.is_empty self.pending_assignments) then + perform_pending_assignments_real_ self + +(** Perform unit propagation in theories. Returns [Some c] if [c] + is a conflict detected during propagation. *) let propagate (self : t) : Conflict.t option = let@ () = Profile.with_ "cdsat.propagate" in try @@ -213,7 +264,16 @@ let propagate (self : t) : Conflict.t option = | None -> assert false in - iter_plugins self ~f:(fun (P p) -> p.propagate p.st self var value); + (* directly give assignment to plugins *) + iter_plugins self ~f:(fun (P p) -> + p.on_assign p.st (self, p.id) var value; + perform_pending_assignments self); + + (* notifier watchers *) + Watches.update self.watches var ~f:(fun ~unit (pl_id, ev) -> + let (P p) = get_plugin self pl_id in + p.on_event p.st (self, p.id) ~unit ev; + perform_pending_assignments self); (* move to next var *) Trail.set_head self.trail (Trail.head self.trail + 1) @@ -234,6 +294,7 @@ let make_sat_res (_self : t) : Check_res.sat_result = iter_true_lits = (fun _ -> assert false); } +(** Make a decision, or return [`Full_model] *) let rec decide (self : t) : [ `Decided | `Full_model ] = match Vars_to_decide.pop_next self.vars_to_decide with | None -> `Full_model @@ -260,6 +321,7 @@ let rec decide (self : t) : [ `Decided | `Full_model ] = `Decided ) +(** Solve satisfiability of the current set of assertions *) let solve ~on_exit ~on_progress ~should_stop ~assumptions (self : t) : Check_res.t = let@ () = Profile.with_ "cdsat.solve" in @@ -319,8 +381,26 @@ let solve ~on_exit ~on_progress ~should_stop ~assumptions (self : t) : module Plugin_action = struct type t = plugin_action - let[@inline] propagate (self : t) var value reason : unit = + let[@inline] propagate ((self, _) : t) var value reason : unit = assign self var ~value ~level:(Reason.level reason) ~reason - let term_to_var = term_to_var + let term_to_var (self, _) t : TVar.t = term_to_var self t + + let watch1 ((self, pl_id) : t) (vars : _ array) (ev : plugin_event) : unit = + let _h : Watches.handle = + Watches.watch1 self.watches vars (pl_id, ev) ~f:(fun ~unit (_, ev) -> + let (P p) = get_plugin self pl_id in + p.on_event p.st (self, pl_id) ~unit ev; + perform_pending_assignments self) + in + () + + let watch2 ((self, pl_id) : t) (vars : _ array) (ev : plugin_event) : unit = + let _h : Watches.handle = + Watches.watch2 self.watches vars (pl_id, ev) ~f:(fun ~unit (_, ev) -> + let (P p) = get_plugin self pl_id in + p.on_event p.st (self, pl_id) ~unit ev; + perform_pending_assignments self) + in + () end diff --git a/src/cdsat/core.mli b/src/cdsat/core.mli index 7f9caae8..badf5dbf 100644 --- a/src/cdsat/core.mli +++ b/src/cdsat/core.mli @@ -11,11 +11,25 @@ end (** {2 Plugins} *) +type plugin_event = .. + +(** Actions passed to plugins *) module Plugin_action : sig type t val propagate : t -> TVar.t -> Value.t -> Reason.t -> unit + (** Propagate given assignment *) + val term_to_var : t -> Term.t -> TVar.t + (** Convert a term to a variable *) + + val watch1 : t -> TVar.t array -> plugin_event -> unit + (** Create a watcher for the given set of variables, which will trigger + the event *) + + val watch2 : t -> TVar.t array -> plugin_event -> unit + (** Create a watcher for the given set of variables, which will trigger + the event *) end (** Core plugin *) @@ -25,14 +39,22 @@ module Plugin : sig val name : t -> string + type event = plugin_event = .. + + type watch_request = + | Watch2 of TVar.t array * event + | Watch1 of TVar.t array * event + val make_builder : name:string -> create:(TVar.store -> 'st) -> push_level:('st -> unit) -> pop_levels:('st -> int -> unit) -> - decide:('st -> TVar.t -> Value.t option) -> - propagate:('st -> Plugin_action.t -> TVar.t -> Value.t -> unit) -> - term_to_var_hooks:('st -> Term_to_var.hook list) -> + ?decide:('st -> TVar.t -> Value.t option) -> + ?on_assign:('st -> Plugin_action.t -> TVar.t -> Value.t -> unit) -> + ?on_event:('st -> Plugin_action.t -> unit:bool -> event -> unit) -> + ?on_add_var:('st -> TVar.t -> watch_request list) -> + ?term_to_var_hooks:('st -> Term_to_var.hook list) -> unit -> builder end diff --git a/src/cdsat/plugin_bool.ml b/src/cdsat/plugin_bool.ml index 1eb531d4..6c38184d 100644 --- a/src/cdsat/plugin_bool.ml +++ b/src/cdsat/plugin_bool.ml @@ -32,12 +32,13 @@ let decide (self : t) (v : TVar.t) : Value.t option = Some (Term.true_ self.tst) | _ -> None -let propagate (self : t) (act : Core.Plugin_action.t) (v : TVar.t) +let on_assign (self : t) (act : Core.Plugin_action.t) (v : TVar.t) (value : Value.t) : unit = Log.debugf 0 (fun k -> - k "(@[bool-plugin.propagate %a@])" (TVar.pp self.vst) v); + k "(@[bool-plugin.on-assign %a@])" (TVar.pp self.vst) v); () -(* TODO: BCP *) + +(* TODO: BCP via on_event *) let term_to_var_hooks (self : t) : _ list = let (module A) = self.arg in @@ -72,4 +73,4 @@ let builder ((module A : ARG) as arg) : Core.Plugin.builder = { arg; vst; tst } in Core.Plugin.make_builder ~name:"bool" ~create ~push_level ~pop_levels ~decide - ~propagate ~term_to_var_hooks () + ~on_assign ~term_to_var_hooks () diff --git a/src/cdsat/plugin_uninterpreted.ml b/src/cdsat/plugin_uninterpreted.ml index 759448c1..e169e06d 100644 --- a/src/cdsat/plugin_uninterpreted.ml +++ b/src/cdsat/plugin_uninterpreted.ml @@ -35,7 +35,8 @@ let pop_levels self n = Cong_tbl.pop_levels self.cong_table n (* let other theories decide, depending on the type *) let decide _ _ = None -let propagate (self : t) _act (v : TVar.t) (value : Value.t) = +(* FIXME: use on_event instead, watch (term + set of args) for congruence *) +let on_assign (self : t) _act (v : TVar.t) (value : Value.t) = match TVar.theory_view self.vst v with | Unin_const _ -> () | Unin_fun { f = _; args } -> @@ -65,4 +66,4 @@ let term_to_var_hooks (self : t) : _ list = let builder ((module A : ARG) as arg) : Core.Plugin.builder = Core.Plugin.make_builder ~name:"uf" ~create:(create arg) ~push_level - ~pop_levels ~decide ~propagate ~term_to_var_hooks () + ~pop_levels ~decide ~on_assign ~term_to_var_hooks () diff --git a/src/cdsat/watch1.ml b/src/cdsat/watch1.ml deleted file mode 100644 index 1a9b64e5..00000000 --- a/src/cdsat/watch1.ml +++ /dev/null @@ -1,38 +0,0 @@ -open Watch_utils_ - -type t = { vst: TVar.store; arr: TVar.t array; mutable alive: bool } - -let make vst l = { alive = true; vst; arr = Array.of_list l } -let[@inline] make_a vst arr : t = { alive = true; vst; arr } -let[@inline] alive self = self.alive -let[@inline] kill self = self.alive <- false - -let[@inline] iter (self : t) k = - if Array.length self.arr > 0 then k self.arr.(0) - -let init (self : t) t ~on_all_set : unit = - let i, all_set = find_watch_ self.vst self.arr 0 0 in - (* put watch first *) - Util.swap_array self.arr i 0; - TVar.add_watcher self.vst self.arr.(0) ~watcher:t; - if all_set then on_all_set (); - () - -let update (self : t) t ~watch ~on_all_set : Watch_res.t = - if self.alive then ( - (* find another watch. If none is present, keep the - current one and call [on_all_set]. *) - assert (self.arr.(0) == watch); - let i, all_set = find_watch_ self.vst self.arr 0 0 in - if all_set then ( - on_all_set (); - Watch_res.Watch_keep (* just keep current watch *) - ) else ( - (* use [i] as the watch *) - assert (i > 0); - Util.swap_array self.arr i 0; - TVar.add_watcher self.vst self.arr.(0) ~watcher:t; - Watch_res.Watch_remove - ) - ) else - Watch_res.Watch_remove diff --git a/src/cdsat/watch1.mli b/src/cdsat/watch1.mli deleted file mode 100644 index aa26a63f..00000000 --- a/src/cdsat/watch1.mli +++ /dev/null @@ -1,29 +0,0 @@ -(** 1-Watch Scheme *) - -type t - -val make : TVar.store -> TVar.t list -> t - -val make_a : TVar.store -> TVar.t array -> t -(** owns the array *) - -val alive : t -> bool -val kill : t -> unit - -val iter : t -> TVar.t Iter.t -(** current watch(es) *) - -val init : t -> TVar.t -> on_all_set:(unit -> unit) -> unit -(** [init w t ~on_all_set] initializes [w] (the watchlist) for - var [t], by finding an unassigned TVar.t in the watchlist and - registering [t] to it. - If all vars are set, then it watches the one with the highest level - and call [on_all_set ()] - *) - -val update : - t -> TVar.t -> watch:TVar.t -> on_all_set:(unit -> unit) -> Watch_res.t -(** [update w t ~watch ~on_all_set] updates [w] after [watch] - has been assigned. It looks for another TVar.t in [w] for [t] to watch. - If all vars are set, then it calls [on_all_set ()] -*) diff --git a/src/cdsat/watch2.ml b/src/cdsat/watch2.ml deleted file mode 100644 index b25765c1..00000000 --- a/src/cdsat/watch2.ml +++ /dev/null @@ -1,61 +0,0 @@ -open Watch_utils_ - -type t = { vst: TVar.store; arr: TVar.t array; mutable alive: bool } - -let dummy = [||] -let make vst l : t = { alive = true; vst; arr = Array.of_list l } -let[@inline] make_a vst arr : t = { vst; arr; alive = true } -let[@inline] alive self = self.alive -let[@inline] kill self = self.alive <- false - -let[@inline] iter (self : t) k = - if Array.length self.arr > 0 then ( - k self.arr.(0); - if Array.length self.arr > 1 then k self.arr.(1) - ) - -let init (self : t) t ~on_unit ~on_all_set : unit = - let i0, all_set0 = find_watch_ self.vst self.arr 0 0 in - Util.swap_array self.arr i0 0; - let i1, all_set1 = find_watch_ self.vst self.arr 1 0 in - Util.swap_array self.arr i1 1; - TVar.add_watcher self.vst self.arr.(0) ~watcher:t; - TVar.add_watcher self.vst self.arr.(1) ~watcher:t; - assert ( - if all_set0 then - all_set1 - else - true); - if all_set0 then - on_all_set () - else if all_set1 then ( - assert (not (TVar.has_value self.vst self.arr.(0))); - on_unit self.arr.(0) - ); - () - -let update (self : t) t ~watch ~on_unit ~on_all_set : Watch_res.t = - if self.alive then ( - (* find another watch. If none is present, keep the - current ones and call [on_unit] or [on_all_set]. *) - if self.arr.(0) == watch then - (* ensure that if there is only one watch, it's the first *) - Util.swap_array self.arr 0 1 - else - assert (self.arr.(1) == watch); - let i, all_set1 = find_watch_ self.vst self.arr 1 0 in - if all_set1 then ( - if TVar.has_value self.vst self.arr.(0) then - on_all_set () - else - on_unit self.arr.(0); - Watch_res.Watch_keep (* just keep current watch *) - ) else ( - (* use [i] as the second watch *) - assert (i > 1); - Util.swap_array self.arr i 1; - TVar.add_watcher self.vst self.arr.(1) ~watcher:t; - Watch_res.Watch_remove - ) - ) else - Watch_res.Watch_remove diff --git a/src/cdsat/watch2.mli b/src/cdsat/watch2.mli deleted file mode 100644 index 3400ab5f..00000000 --- a/src/cdsat/watch2.mli +++ /dev/null @@ -1,40 +0,0 @@ -(** 2-Watch Scheme *) - -type t - -val make : TVar.store -> TVar.t list -> t - -val make_a : TVar.store -> TVar.t array -> t -(** owns the array *) - -val iter : t -> TVar.t Iter.t -(** current watch(es) *) - -val kill : t -> unit -(** Disable the watch. It will be removed lazily. *) - -val alive : t -> bool -(** Is the watch alive? *) - -val init : - t -> TVar.t -> on_unit:(TVar.t -> unit) -> on_all_set:(unit -> unit) -> unit -(** [init w t ~on_all_set] initializes [w] (the watchlist) for - var [t], by finding an unassigned var in the watchlist and - registering [t] to it. - If exactly one TVar.t [u] is not set, then it calls [on_unit u]. - If all vars are set, then it watches the one with the highest level - and call [on_all_set ()] - *) - -val update : - t -> - TVar.t -> - watch:TVar.t -> - on_unit:(TVar.t -> unit) -> - on_all_set:(unit -> unit) -> - Watch_res.t -(** [update w t ~watch ~on_all_set] updates [w] after [watch] - has been assigned. It looks for another var in [w] for [t] to watch. - If exactly one var [u] is not set, then it calls [on_unit u]. - If all vars are set, then it calls [on_all_set ()] -*) diff --git a/src/cdsat/watch_res.ml b/src/cdsat/watch_res.ml deleted file mode 100644 index f14eeee8..00000000 --- a/src/cdsat/watch_res.ml +++ /dev/null @@ -1,3 +0,0 @@ -type t = - | Watch_keep (** Keep the watch *) - | Watch_remove (** Remove the watch *) diff --git a/src/cdsat/watch_schemes.ml b/src/cdsat/watch_schemes.ml new file mode 100644 index 00000000..9823f956 --- /dev/null +++ b/src/cdsat/watch_schemes.ml @@ -0,0 +1,186 @@ +type watch_update_res = + | Watch_keep (** Keep the watch *) + | Watch_remove (** Remove the watch *) + +(* find a term in [w] that is not assigned, or otherwise, + the one with highest level + @return index of term to watch, and [true] if all are assigned *) +let find_watch_ tst w ~idx0 : int * bool = + let rec loop i idx_with_highest_level = + if i = Array.length w then + idx_with_highest_level, true + else if TVar.has_value tst w.(i) then ( + let idx_with_highest_level = + if TVar.level tst w.(i) > TVar.level tst w.(idx_with_highest_level) then + i + else + idx_with_highest_level + in + loop (i + 1) idx_with_highest_level + ) else + i, false + in + loop idx0 0 + +module Make (Ev : sig + type t +end) = +struct + type handle = int + + module Handle_v_map = TVar.Dense_map (struct + type t = Veci.t + + let create () = Veci.create ~cap:2 () + end) + + type watch = + | No_watch + | Watch1 of { ev: Ev.t; arr: TVar.t array } + | Watch2 of { ev: Ev.t; arr: TVar.t array } + + type t = { + vst: TVar.store; + watches: watch Vec.t; + by_var: Handle_v_map.t; + (** maps a variable to the handles of its watchers *) + alive: Bitvec.t; + free_slots: Veci.t; + } + + let create vst : t = + { + vst; + watches = Vec.create (); + by_var = Handle_v_map.create (); + alive = Bitvec.create (); + free_slots = Veci.create (); + } + + type cb_ = unit:bool -> Ev.t -> unit + + (* allocate new watch *) + let make_watch_ (self : t) (w : watch) : handle = + if Veci.is_empty self.free_slots then ( + let h = Vec.size self.watches in + Vec.push self.watches w; + Bitvec.ensure_size self.alive (h + 1); + Bitvec.set self.alive h true; + h + ) else ( + let h = Veci.pop self.free_slots in + Bitvec.set self.alive h true; + Vec.set self.watches h w; + h + ) + + (* [h] is currently watching [v] *) + let set_watch (self : t) (v : TVar.t) (h : handle) : unit = + let vec = Handle_v_map.get self.by_var v in + Veci.push vec h + + let watch1 (self : t) (arr : TVar.t array) (ev : Ev.t) ~(f : cb_) : handle = + let h = make_watch_ self (Watch1 { arr; ev }) in + let i, all_set = find_watch_ self.vst arr ~idx0:0 in + (* put watched var first *) + Util.swap_array arr i 0; + set_watch self arr.(0) h; + if all_set then f ~unit:false ev; + h + + let watch2 (self : t) (arr : TVar.t array) (ev : Ev.t) ~(f : cb_) : handle = + let h = make_watch_ self (Watch2 { arr; ev }) in + (* put watched vars first *) + let i0, all_set0 = find_watch_ self.vst arr ~idx0:0 in + Util.swap_array arr i0 0; + let i1, all_set1 = find_watch_ self.vst arr ~idx0:1 in + Util.swap_array arr i1 1; + set_watch self arr.(0) h; + set_watch self arr.(1) h; + assert ( + if all_set0 then + all_set1 + else + true); + if all_set0 then + f ~unit:false ev + else if all_set1 then ( + assert (not (TVar.has_value self.vst arr.(0))); + f ~unit:true ev + ); + h + + (** disable watch. It will be removed from watchers next time they + are updated or next time {!gc} is called. *) + let kill (self : t) (h : handle) : unit = + if Bitvec.get self.alive h then ( + Vec.set self.watches h No_watch; + Bitvec.set self.alive h false + ) + + let[@inline] alive (self : t) (h : handle) : bool = Bitvec.get self.alive h + + let gc (self : t) : unit = + (* first, filter all dead watches from [self.by_var] *) + Handle_v_map.iter self.by_var ~f:(fun _v handles -> + Veci.filter_in_place (alive self) handles); + (* then, mark the dead watch slots for reuse *) + Vec.iteri self.watches ~f:(fun i _w -> + if not (alive self i) then Veci.push self.free_slots i) + + (* update a single watch *) + let update1 (self : t) (h : handle) (w : watch) ~updated_var ~f : + watch_update_res = + match w with + | No_watch -> Watch_remove + | _ when not (alive self h) -> Watch_remove + | Watch1 { arr; ev } -> + (* find another watch. If none is present, keep the + current one and call [f]. *) + assert (TVar.equal arr.(0) updated_var); + let i, all_set = find_watch_ self.vst arr ~idx0:0 in + if all_set then ( + f ~unit:false ev; + Watch_keep (* just keep current watch *) + ) else ( + (* use [i] as the watch *) + assert (i > 0); + Util.swap_array arr i 0; + set_watch self arr.(0) h; + Watch_remove + ) + | Watch2 { arr; ev } -> + (* find another watch. If none is present, keep the + current ones and call [f]. *) + if TVar.equal arr.(0) updated_var then + (* ensure that if there is only one watch, it's the first *) + Util.swap_array arr 0 1 + else + assert (TVar.equal arr.(1) updated_var); + let i, all_set1 = find_watch_ self.vst arr ~idx0:1 in + if all_set1 then ( + if TVar.has_value self.vst arr.(0) then + f ~unit:false ev + else + f ~unit:true ev; + (* just keep current watch *) + Watch_keep + ) else ( + (* use [i] as the second watch *) + assert (i > 1); + Util.swap_array arr i 1; + set_watch self arr.(1) h; + Watch_remove + ) + + let update (self : t) (v : TVar.t) ~(f : cb_) : unit = + let vec = Handle_v_map.get self.by_var v in + let i = ref 0 in + while !i < Veci.size vec do + let handle = Veci.get vec !i in + let w = Vec.get self.watches handle in + match update1 self handle w ~updated_var:v ~f with + | Watch_keep -> incr i + | Watch_remove -> Veci.fast_remove vec !i + done +end diff --git a/src/cdsat/watch_schemes.mli b/src/cdsat/watch_schemes.mli new file mode 100644 index 00000000..4ec51f14 --- /dev/null +++ b/src/cdsat/watch_schemes.mli @@ -0,0 +1,31 @@ +(** Watch schemes *) + +module Make (Ev : sig + type t +end) : sig + type t + type handle + + val create : TVar.store -> t + (** New set of watchers *) + + val watch2 : + t -> TVar.t array -> Ev.t -> f:(unit:bool -> Ev.t -> unit) -> handle + (** 2-watch scheme on these variables. *) + + val watch1 : + t -> TVar.t array -> Ev.t -> f:(unit:bool -> Ev.t -> unit) -> handle + (** 1-watch scheme on these variables. *) + + val kill : t -> handle -> unit + (** Disable watch *) + + val gc : t -> unit + (** Reclaim slots that have been killed *) + + val update : t -> TVar.t -> f:(unit:bool -> Ev.t -> unit) -> unit + (** [update watches v ~f] updates watchers that contain [v], + and calls [f ~unit ev] for each event whose watch just saturated. + [unit] is true if the watch is a 2-watch that became unit; [false] in + any other case (including a fully saturated 2-watch) *) +end diff --git a/src/cdsat/watch_utils_.ml b/src/cdsat/watch_utils_.ml deleted file mode 100644 index fd3ea00d..00000000 --- a/src/cdsat/watch_utils_.ml +++ /dev/null @@ -1,16 +0,0 @@ -(* find a term in [w] that is not assigned, or otherwise, - the one with highest level - @return index of term to watch, and [true] if all are assigned *) -let rec find_watch_ tst w i highest : int * bool = - if i = Array.length w then - highest, true - else if TVar.has_value tst w.(i) then ( - let highest = - if TVar.level tst w.(i) > TVar.level tst w.(highest) then - i - else - highest - in - find_watch_ tst w (i + 1) highest - ) else - i, false