feat(cdsat): revamp watches

This commit is contained in:
Simon Cruanes 2022-11-08 10:53:38 -05:00
parent c34e648148
commit c5f00b5204
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
14 changed files with 407 additions and 244 deletions

View file

@ -25,9 +25,8 @@ type store = {
value: Value.t option Vec.t;
reason: reason Vec.t;
theory_views: theory_view Vec.t;
watches: t Vec.t Vec.t;
has_value: Bitvec.t;
new_vars: Vec_of.t;
new_vars: Vec_of.t; (* TODO: a recycle vec to reuse identifiers *)
}
(* create a new variable *)
@ -39,7 +38,6 @@ let new_var_ (self : store) ~term:(term_for_v : Term.t) ~theory_view () : t =
term;
level;
value;
watches;
reason;
theory_views;
has_value;
@ -52,7 +50,6 @@ let new_var_ (self : store) ~term:(term_for_v : Term.t) ~theory_view () : t =
Vec.push value None;
(* fake *)
Vec.push reason dummy_reason_;
Vec.push watches (Vec.create ());
Vec.push theory_views theory_view;
Bitvec.ensure_size has_value (v + 1);
Bitvec.set has_value v false;
@ -65,6 +62,9 @@ let[@inline] get_of_term (self : store) (t : Term.t) : t option =
let[@inline] has_value (self : store) (v : t) : bool =
Bitvec.get self.has_value v
let[@inline] equal (a : t) (b : t) = a = b
let[@inline] compare (a : t) (b : t) = compare a b
let[@inline] hash (a : t) = CCHash.int a
let[@inline] level (self : store) (v : t) : int = Veci.get self.level v
let[@inline] value (self : store) (v : t) : _ option = Vec.get self.value v
let[@inline] theory_view (self : store) (v : t) = Vec.get self.theory_views v
@ -77,10 +77,6 @@ let[@inline] bool_value (self : store) (v : t) : _ option =
let[@inline] term (self : store) (v : t) : Term.t = Vec.get self.term v
let[@inline] reason (self : store) (v : t) : reason = Vec.get self.reason v
let[@inline] watchers (self : store) (v : t) : t Vec.t = Vec.get self.watches v
let[@inline] add_watcher (self : store) (v : t) ~watcher : unit =
Vec.push (watchers self v) watcher
let assign (self : store) (v : t) ~value ~level ~reason : unit =
Log.debugf 50 (fun k ->
@ -115,7 +111,6 @@ module Store = struct
reason = Vec.create ();
term = Vec.create ();
level = Veci.create ();
watches = Vec.create ();
value = Vec.create ();
theory_views = Vec.create ();
has_value = Bitvec.create ();
@ -133,6 +128,28 @@ module Tbl = Util.Int_tbl
module Set = Util.Int_set
module Map = Util.Int_map
module Dense_map (Elt : sig
type t
val create : unit -> t
end) =
struct
type elt = Elt.t
type t = { v: elt Vec.t } [@@unboxed]
let create () : t = { v = Vec.create () }
let[@inline] get self v =
Vec.ensure_size_with self.v Elt.create (v + 1);
Vec.get self.v v
let[@inline] set self v x =
Vec.ensure_size_with self.v Elt.create (v + 1);
Vec.set self.v v x
let[@inline] iter self ~f = Vec.iteri self.v ~f
end
module Internal = struct
let create (self : store) (t : Term.t) ~theory_view : t =
assert (not @@ Term.Weak_map.mem self.of_term t);

View file

@ -23,6 +23,8 @@ end
module Vec_of : Vec_sig.S with type elt := t
(** Vector of variables *)
include Sidekick_sigs.EQ_ORD_HASH with type t := t
type store = Store.t
type reason =
@ -53,13 +55,6 @@ val theory_view : store -> t -> theory_view
val assign : store -> t -> value:Value.t -> level:int -> reason:reason -> unit
val unassign : store -> t -> unit
val watchers : store -> t -> t Vec.t
(** [watchers store t] is a vector of other variables watching [t],
generally updated via {!Watch1} and {!Watch2}.
These other variables are notified when [t] is assigned. *)
val add_watcher : store -> t -> watcher:t -> unit
val pop_new_var : store -> t option
(** Pop a new variable if any, or return [None] *)
@ -69,6 +64,23 @@ module Tbl : CCHashtbl.S with type key = t
module Set : CCSet.S with type elt = t
module Map : CCMap.S with type key = t
(** A map optimized for dense storage.
This is useful when most variables have an entry in the map. *)
module Dense_map (Elt : sig
type t
val create : unit -> t
end) : sig
type elt = Elt.t
type t
val create : unit -> t
val get : t -> var -> elt
val set : t -> var -> elt -> unit
val iter : t -> f:(var -> elt -> unit) -> unit
end
(**/**)
module Internal : sig

View file

@ -19,6 +19,19 @@ type pending_assignment = {
reason: Reason.t;
}
type plugin_id = int
(** Each plugin gets a unique identifier *)
type plugin_event = ..
type watch_request =
| Watch2 of TVar.t array * plugin_event
| Watch1 of TVar.t array * plugin_event
module Watches = Watch_schemes.Make (struct
type t = plugin_id * plugin_event
end)
type t = {
tst: Term.store;
vst: TVar.store;
@ -29,6 +42,7 @@ type t = {
term_to_var: Term_to_var.t;
vars_to_decide: Vars_to_decide.t;
pending_assignments: pending_assignment Vec.t;
watches: Watches.t;
mutable last_res: Check_res.t option;
proof_tracer: Proof.Tracer.t;
n_conflicts: int Stat.counter;
@ -36,7 +50,7 @@ type t = {
n_restarts: int Stat.counter;
}
and plugin_action = t
and plugin_action = t * plugin_id
(* FIXME:
- add [on_add_var: TVar.t -> unit]
@ -51,12 +65,15 @@ and plugin_action = t
and plugin =
| P : {
st: 'st;
id: plugin_id;
name: string;
push_level: 'st -> unit;
pop_levels: 'st -> int -> unit;
decide: 'st -> TVar.t -> Value.t option;
propagate: 'st -> plugin_action -> TVar.t -> Value.t -> unit;
on_assign: 'st -> plugin_action -> TVar.t -> Value.t -> unit;
on_event: 'st -> plugin_action -> unit:bool -> plugin_event -> unit;
term_to_var_hooks: 'st -> Term_to_var.hook list;
on_add_var: 'st -> TVar.t -> watch_request list;
}
-> plugin
@ -71,6 +88,7 @@ let create ?(stats = Stat.create ()) ~arg tst vst ~proof_tracer () : t =
pending_assignments = Vec.create ();
term_to_var = Term_to_var.create vst;
vars_to_decide = Vars_to_decide.create ();
watches = Watches.create vst;
last_res = None;
proof_tracer;
n_restarts = Stat.mk_int stats "cdsat.restarts";
@ -80,6 +98,7 @@ let create ?(stats = Stat.create ()) ~arg tst vst ~proof_tracer () : t =
let[@inline] trail self = self.trail
let[@inline] iter_plugins self ~f = Vec.iter ~f self.plugins
let[@inline] get_plugin (self : t) (id : plugin_id) = Vec.get self.plugins id
let[@inline] tst self = self.tst
let[@inline] vst self = self.vst
let[@inline] last_res self = self.last_res
@ -88,15 +107,35 @@ let[@inline] last_res self = self.last_res
module Plugin = struct
type t = plugin
type builder = TVar.store -> t
type builder = id:plugin_id -> TVar.store -> t
let[@inline] name (P p) = p.name
let make_builder ~name ~create ~push_level ~pop_levels ~decide ~propagate
~term_to_var_hooks () : builder =
fun vst ->
type nonrec event = plugin_event = ..
type nonrec watch_request = watch_request =
| Watch2 of TVar.t array * event
| Watch1 of TVar.t array * event
let make_builder ~name ~create ~push_level ~pop_levels
?(decide = fun _ _ -> None) ?(on_assign = fun _ _ _ _ -> ())
?(on_event = fun _ _ ~unit:_ _ -> ()) ?(on_add_var = fun _ _ -> [])
?(term_to_var_hooks = fun _ -> []) () : builder =
fun ~id vst ->
let st = create vst in
P { name; st; push_level; pop_levels; decide; propagate; term_to_var_hooks }
P
{
name;
id;
st;
push_level;
pop_levels;
decide;
on_assign;
on_event;
term_to_var_hooks;
on_add_var;
}
end
(* backtracking *)
@ -118,6 +157,7 @@ let pop_levels (self : t) n : unit =
trail;
plugins;
term_to_var = _;
watches = _;
vars_to_decide = _;
pending_assignments;
last_res = _;
@ -149,7 +189,8 @@ let add_term_to_var_hook self h = Term_to_var.add_hook self.term_to_var h
(* plugins *)
let add_plugin self (pb : Plugin.builder) : unit =
let (P p as plugin) = pb self.vst in
let id = Vec.size self.plugins in
let (P p as plugin) = pb ~id self.vst in
Vec.push self.plugins plugin;
List.iter (add_term_to_var_hook self) (p.term_to_var_hooks p.st)
@ -157,8 +198,8 @@ let add_plugin self (pb : Plugin.builder) : unit =
let add_ty (_self : t) ~ty:_ : unit = ()
(* Assign [v <- value] for [reason] at [level].
This assignment is delayed. *)
(** Assign [v <- value] for [reason] at [level].
This assignment is delayed. *)
let assign (self : t) (v : TVar.t) ~(value : Value.t) ~level:v_level ~reason :
unit =
Log.debugf 50 (fun k ->
@ -173,27 +214,37 @@ let raise_conflict (c : Conflict.t) : 'a = raise (E_conflict c)
(* add pending assignments to the trail. This might trigger a conflict
in case an assignment contradicts an already existing assignment. *)
let perform_pending_assignments (self : t) : unit =
while not (Vec.is_empty self.pending_assignments) do
let { var = v; level = v_level; value; reason } =
Vec.pop_exn self.pending_assignments
in
match TVar.value self.vst v with
| None ->
TVar.assign self.vst v ~value ~level:v_level ~reason;
Trail.push_assignment self.trail v
| Some value' when Value.equal value value' -> () (* idempotent *)
| Some _value' ->
(* conflict should only occur on booleans since they're the only
propagation-able variables *)
assert (Term.has_ty_bool (TVar.term self.vst v));
Log.debugf 0 (fun k ->
k "TODO: conflict (incompatible values for %a)" (TVar.pp self.vst) v);
raise_conflict
@@ Conflict.make self.vst ~lit:(TLit.make true v) ~propagate_reason:reason
()
let perform_pending_assignments_real_ (self : t) : unit =
while
match Vec.pop self.pending_assignments with
| None -> false
| Some { var = v; level = v_level; value; reason } ->
(match TVar.value self.vst v with
| None ->
(* assign [v], put it on the trail. Do not notify watchers yet. *)
TVar.assign self.vst v ~value ~level:v_level ~reason;
Trail.push_assignment self.trail v
| Some value' when Value.equal value value' -> () (* idempotent *)
| Some _value' ->
(* conflict should only occur on booleans since they're the only
propagation-able variables *)
assert (Term.has_ty_bool (TVar.term self.vst v));
Log.debugf 0 (fun k ->
k "TODO: conflict (incompatible values for %a)" (TVar.pp self.vst) v);
raise_conflict
@@ Conflict.make self.vst ~lit:(TLit.make true v)
~propagate_reason:reason ());
true
do
()
done
let[@inline] perform_pending_assignments (self : t) : unit =
if not (Vec.is_empty self.pending_assignments) then
perform_pending_assignments_real_ self
(** Perform unit propagation in theories. Returns [Some c] if [c]
is a conflict detected during propagation. *)
let propagate (self : t) : Conflict.t option =
let@ () = Profile.with_ "cdsat.propagate" in
try
@ -213,7 +264,16 @@ let propagate (self : t) : Conflict.t option =
| None -> assert false
in
iter_plugins self ~f:(fun (P p) -> p.propagate p.st self var value);
(* directly give assignment to plugins *)
iter_plugins self ~f:(fun (P p) ->
p.on_assign p.st (self, p.id) var value;
perform_pending_assignments self);
(* notifier watchers *)
Watches.update self.watches var ~f:(fun ~unit (pl_id, ev) ->
let (P p) = get_plugin self pl_id in
p.on_event p.st (self, p.id) ~unit ev;
perform_pending_assignments self);
(* move to next var *)
Trail.set_head self.trail (Trail.head self.trail + 1)
@ -234,6 +294,7 @@ let make_sat_res (_self : t) : Check_res.sat_result =
iter_true_lits = (fun _ -> assert false);
}
(** Make a decision, or return [`Full_model] *)
let rec decide (self : t) : [ `Decided | `Full_model ] =
match Vars_to_decide.pop_next self.vars_to_decide with
| None -> `Full_model
@ -260,6 +321,7 @@ let rec decide (self : t) : [ `Decided | `Full_model ] =
`Decided
)
(** Solve satisfiability of the current set of assertions *)
let solve ~on_exit ~on_progress ~should_stop ~assumptions (self : t) :
Check_res.t =
let@ () = Profile.with_ "cdsat.solve" in
@ -319,8 +381,26 @@ let solve ~on_exit ~on_progress ~should_stop ~assumptions (self : t) :
module Plugin_action = struct
type t = plugin_action
let[@inline] propagate (self : t) var value reason : unit =
let[@inline] propagate ((self, _) : t) var value reason : unit =
assign self var ~value ~level:(Reason.level reason) ~reason
let term_to_var = term_to_var
let term_to_var (self, _) t : TVar.t = term_to_var self t
let watch1 ((self, pl_id) : t) (vars : _ array) (ev : plugin_event) : unit =
let _h : Watches.handle =
Watches.watch1 self.watches vars (pl_id, ev) ~f:(fun ~unit (_, ev) ->
let (P p) = get_plugin self pl_id in
p.on_event p.st (self, pl_id) ~unit ev;
perform_pending_assignments self)
in
()
let watch2 ((self, pl_id) : t) (vars : _ array) (ev : plugin_event) : unit =
let _h : Watches.handle =
Watches.watch2 self.watches vars (pl_id, ev) ~f:(fun ~unit (_, ev) ->
let (P p) = get_plugin self pl_id in
p.on_event p.st (self, pl_id) ~unit ev;
perform_pending_assignments self)
in
()
end

View file

@ -11,11 +11,25 @@ end
(** {2 Plugins} *)
type plugin_event = ..
(** Actions passed to plugins *)
module Plugin_action : sig
type t
val propagate : t -> TVar.t -> Value.t -> Reason.t -> unit
(** Propagate given assignment *)
val term_to_var : t -> Term.t -> TVar.t
(** Convert a term to a variable *)
val watch1 : t -> TVar.t array -> plugin_event -> unit
(** Create a watcher for the given set of variables, which will trigger
the event *)
val watch2 : t -> TVar.t array -> plugin_event -> unit
(** Create a watcher for the given set of variables, which will trigger
the event *)
end
(** Core plugin *)
@ -25,14 +39,22 @@ module Plugin : sig
val name : t -> string
type event = plugin_event = ..
type watch_request =
| Watch2 of TVar.t array * event
| Watch1 of TVar.t array * event
val make_builder :
name:string ->
create:(TVar.store -> 'st) ->
push_level:('st -> unit) ->
pop_levels:('st -> int -> unit) ->
decide:('st -> TVar.t -> Value.t option) ->
propagate:('st -> Plugin_action.t -> TVar.t -> Value.t -> unit) ->
term_to_var_hooks:('st -> Term_to_var.hook list) ->
?decide:('st -> TVar.t -> Value.t option) ->
?on_assign:('st -> Plugin_action.t -> TVar.t -> Value.t -> unit) ->
?on_event:('st -> Plugin_action.t -> unit:bool -> event -> unit) ->
?on_add_var:('st -> TVar.t -> watch_request list) ->
?term_to_var_hooks:('st -> Term_to_var.hook list) ->
unit ->
builder
end

View file

@ -32,12 +32,13 @@ let decide (self : t) (v : TVar.t) : Value.t option =
Some (Term.true_ self.tst)
| _ -> None
let propagate (self : t) (act : Core.Plugin_action.t) (v : TVar.t)
let on_assign (self : t) (act : Core.Plugin_action.t) (v : TVar.t)
(value : Value.t) : unit =
Log.debugf 0 (fun k ->
k "(@[bool-plugin.propagate %a@])" (TVar.pp self.vst) v);
k "(@[bool-plugin.on-assign %a@])" (TVar.pp self.vst) v);
()
(* TODO: BCP *)
(* TODO: BCP via on_event *)
let term_to_var_hooks (self : t) : _ list =
let (module A) = self.arg in
@ -72,4 +73,4 @@ let builder ((module A : ARG) as arg) : Core.Plugin.builder =
{ arg; vst; tst }
in
Core.Plugin.make_builder ~name:"bool" ~create ~push_level ~pop_levels ~decide
~propagate ~term_to_var_hooks ()
~on_assign ~term_to_var_hooks ()

View file

@ -35,7 +35,8 @@ let pop_levels self n = Cong_tbl.pop_levels self.cong_table n
(* let other theories decide, depending on the type *)
let decide _ _ = None
let propagate (self : t) _act (v : TVar.t) (value : Value.t) =
(* FIXME: use on_event instead, watch (term + set of args) for congruence *)
let on_assign (self : t) _act (v : TVar.t) (value : Value.t) =
match TVar.theory_view self.vst v with
| Unin_const _ -> ()
| Unin_fun { f = _; args } ->
@ -65,4 +66,4 @@ let term_to_var_hooks (self : t) : _ list =
let builder ((module A : ARG) as arg) : Core.Plugin.builder =
Core.Plugin.make_builder ~name:"uf" ~create:(create arg) ~push_level
~pop_levels ~decide ~propagate ~term_to_var_hooks ()
~pop_levels ~decide ~on_assign ~term_to_var_hooks ()

View file

@ -1,38 +0,0 @@
open Watch_utils_
type t = { vst: TVar.store; arr: TVar.t array; mutable alive: bool }
let make vst l = { alive = true; vst; arr = Array.of_list l }
let[@inline] make_a vst arr : t = { alive = true; vst; arr }
let[@inline] alive self = self.alive
let[@inline] kill self = self.alive <- false
let[@inline] iter (self : t) k =
if Array.length self.arr > 0 then k self.arr.(0)
let init (self : t) t ~on_all_set : unit =
let i, all_set = find_watch_ self.vst self.arr 0 0 in
(* put watch first *)
Util.swap_array self.arr i 0;
TVar.add_watcher self.vst self.arr.(0) ~watcher:t;
if all_set then on_all_set ();
()
let update (self : t) t ~watch ~on_all_set : Watch_res.t =
if self.alive then (
(* find another watch. If none is present, keep the
current one and call [on_all_set]. *)
assert (self.arr.(0) == watch);
let i, all_set = find_watch_ self.vst self.arr 0 0 in
if all_set then (
on_all_set ();
Watch_res.Watch_keep (* just keep current watch *)
) else (
(* use [i] as the watch *)
assert (i > 0);
Util.swap_array self.arr i 0;
TVar.add_watcher self.vst self.arr.(0) ~watcher:t;
Watch_res.Watch_remove
)
) else
Watch_res.Watch_remove

View file

@ -1,29 +0,0 @@
(** 1-Watch Scheme *)
type t
val make : TVar.store -> TVar.t list -> t
val make_a : TVar.store -> TVar.t array -> t
(** owns the array *)
val alive : t -> bool
val kill : t -> unit
val iter : t -> TVar.t Iter.t
(** current watch(es) *)
val init : t -> TVar.t -> on_all_set:(unit -> unit) -> unit
(** [init w t ~on_all_set] initializes [w] (the watchlist) for
var [t], by finding an unassigned TVar.t in the watchlist and
registering [t] to it.
If all vars are set, then it watches the one with the highest level
and call [on_all_set ()]
*)
val update :
t -> TVar.t -> watch:TVar.t -> on_all_set:(unit -> unit) -> Watch_res.t
(** [update w t ~watch ~on_all_set] updates [w] after [watch]
has been assigned. It looks for another TVar.t in [w] for [t] to watch.
If all vars are set, then it calls [on_all_set ()]
*)

View file

@ -1,61 +0,0 @@
open Watch_utils_
type t = { vst: TVar.store; arr: TVar.t array; mutable alive: bool }
let dummy = [||]
let make vst l : t = { alive = true; vst; arr = Array.of_list l }
let[@inline] make_a vst arr : t = { vst; arr; alive = true }
let[@inline] alive self = self.alive
let[@inline] kill self = self.alive <- false
let[@inline] iter (self : t) k =
if Array.length self.arr > 0 then (
k self.arr.(0);
if Array.length self.arr > 1 then k self.arr.(1)
)
let init (self : t) t ~on_unit ~on_all_set : unit =
let i0, all_set0 = find_watch_ self.vst self.arr 0 0 in
Util.swap_array self.arr i0 0;
let i1, all_set1 = find_watch_ self.vst self.arr 1 0 in
Util.swap_array self.arr i1 1;
TVar.add_watcher self.vst self.arr.(0) ~watcher:t;
TVar.add_watcher self.vst self.arr.(1) ~watcher:t;
assert (
if all_set0 then
all_set1
else
true);
if all_set0 then
on_all_set ()
else if all_set1 then (
assert (not (TVar.has_value self.vst self.arr.(0)));
on_unit self.arr.(0)
);
()
let update (self : t) t ~watch ~on_unit ~on_all_set : Watch_res.t =
if self.alive then (
(* find another watch. If none is present, keep the
current ones and call [on_unit] or [on_all_set]. *)
if self.arr.(0) == watch then
(* ensure that if there is only one watch, it's the first *)
Util.swap_array self.arr 0 1
else
assert (self.arr.(1) == watch);
let i, all_set1 = find_watch_ self.vst self.arr 1 0 in
if all_set1 then (
if TVar.has_value self.vst self.arr.(0) then
on_all_set ()
else
on_unit self.arr.(0);
Watch_res.Watch_keep (* just keep current watch *)
) else (
(* use [i] as the second watch *)
assert (i > 1);
Util.swap_array self.arr i 1;
TVar.add_watcher self.vst self.arr.(1) ~watcher:t;
Watch_res.Watch_remove
)
) else
Watch_res.Watch_remove

View file

@ -1,40 +0,0 @@
(** 2-Watch Scheme *)
type t
val make : TVar.store -> TVar.t list -> t
val make_a : TVar.store -> TVar.t array -> t
(** owns the array *)
val iter : t -> TVar.t Iter.t
(** current watch(es) *)
val kill : t -> unit
(** Disable the watch. It will be removed lazily. *)
val alive : t -> bool
(** Is the watch alive? *)
val init :
t -> TVar.t -> on_unit:(TVar.t -> unit) -> on_all_set:(unit -> unit) -> unit
(** [init w t ~on_all_set] initializes [w] (the watchlist) for
var [t], by finding an unassigned var in the watchlist and
registering [t] to it.
If exactly one TVar.t [u] is not set, then it calls [on_unit u].
If all vars are set, then it watches the one with the highest level
and call [on_all_set ()]
*)
val update :
t ->
TVar.t ->
watch:TVar.t ->
on_unit:(TVar.t -> unit) ->
on_all_set:(unit -> unit) ->
Watch_res.t
(** [update w t ~watch ~on_all_set] updates [w] after [watch]
has been assigned. It looks for another var in [w] for [t] to watch.
If exactly one var [u] is not set, then it calls [on_unit u].
If all vars are set, then it calls [on_all_set ()]
*)

View file

@ -1,3 +0,0 @@
type t =
| Watch_keep (** Keep the watch *)
| Watch_remove (** Remove the watch *)

186
src/cdsat/watch_schemes.ml Normal file
View file

@ -0,0 +1,186 @@
type watch_update_res =
| Watch_keep (** Keep the watch *)
| Watch_remove (** Remove the watch *)
(* find a term in [w] that is not assigned, or otherwise,
the one with highest level
@return index of term to watch, and [true] if all are assigned *)
let find_watch_ tst w ~idx0 : int * bool =
let rec loop i idx_with_highest_level =
if i = Array.length w then
idx_with_highest_level, true
else if TVar.has_value tst w.(i) then (
let idx_with_highest_level =
if TVar.level tst w.(i) > TVar.level tst w.(idx_with_highest_level) then
i
else
idx_with_highest_level
in
loop (i + 1) idx_with_highest_level
) else
i, false
in
loop idx0 0
module Make (Ev : sig
type t
end) =
struct
type handle = int
module Handle_v_map = TVar.Dense_map (struct
type t = Veci.t
let create () = Veci.create ~cap:2 ()
end)
type watch =
| No_watch
| Watch1 of { ev: Ev.t; arr: TVar.t array }
| Watch2 of { ev: Ev.t; arr: TVar.t array }
type t = {
vst: TVar.store;
watches: watch Vec.t;
by_var: Handle_v_map.t;
(** maps a variable to the handles of its watchers *)
alive: Bitvec.t;
free_slots: Veci.t;
}
let create vst : t =
{
vst;
watches = Vec.create ();
by_var = Handle_v_map.create ();
alive = Bitvec.create ();
free_slots = Veci.create ();
}
type cb_ = unit:bool -> Ev.t -> unit
(* allocate new watch *)
let make_watch_ (self : t) (w : watch) : handle =
if Veci.is_empty self.free_slots then (
let h = Vec.size self.watches in
Vec.push self.watches w;
Bitvec.ensure_size self.alive (h + 1);
Bitvec.set self.alive h true;
h
) else (
let h = Veci.pop self.free_slots in
Bitvec.set self.alive h true;
Vec.set self.watches h w;
h
)
(* [h] is currently watching [v] *)
let set_watch (self : t) (v : TVar.t) (h : handle) : unit =
let vec = Handle_v_map.get self.by_var v in
Veci.push vec h
let watch1 (self : t) (arr : TVar.t array) (ev : Ev.t) ~(f : cb_) : handle =
let h = make_watch_ self (Watch1 { arr; ev }) in
let i, all_set = find_watch_ self.vst arr ~idx0:0 in
(* put watched var first *)
Util.swap_array arr i 0;
set_watch self arr.(0) h;
if all_set then f ~unit:false ev;
h
let watch2 (self : t) (arr : TVar.t array) (ev : Ev.t) ~(f : cb_) : handle =
let h = make_watch_ self (Watch2 { arr; ev }) in
(* put watched vars first *)
let i0, all_set0 = find_watch_ self.vst arr ~idx0:0 in
Util.swap_array arr i0 0;
let i1, all_set1 = find_watch_ self.vst arr ~idx0:1 in
Util.swap_array arr i1 1;
set_watch self arr.(0) h;
set_watch self arr.(1) h;
assert (
if all_set0 then
all_set1
else
true);
if all_set0 then
f ~unit:false ev
else if all_set1 then (
assert (not (TVar.has_value self.vst arr.(0)));
f ~unit:true ev
);
h
(** disable watch. It will be removed from watchers next time they
are updated or next time {!gc} is called. *)
let kill (self : t) (h : handle) : unit =
if Bitvec.get self.alive h then (
Vec.set self.watches h No_watch;
Bitvec.set self.alive h false
)
let[@inline] alive (self : t) (h : handle) : bool = Bitvec.get self.alive h
let gc (self : t) : unit =
(* first, filter all dead watches from [self.by_var] *)
Handle_v_map.iter self.by_var ~f:(fun _v handles ->
Veci.filter_in_place (alive self) handles);
(* then, mark the dead watch slots for reuse *)
Vec.iteri self.watches ~f:(fun i _w ->
if not (alive self i) then Veci.push self.free_slots i)
(* update a single watch *)
let update1 (self : t) (h : handle) (w : watch) ~updated_var ~f :
watch_update_res =
match w with
| No_watch -> Watch_remove
| _ when not (alive self h) -> Watch_remove
| Watch1 { arr; ev } ->
(* find another watch. If none is present, keep the
current one and call [f]. *)
assert (TVar.equal arr.(0) updated_var);
let i, all_set = find_watch_ self.vst arr ~idx0:0 in
if all_set then (
f ~unit:false ev;
Watch_keep (* just keep current watch *)
) else (
(* use [i] as the watch *)
assert (i > 0);
Util.swap_array arr i 0;
set_watch self arr.(0) h;
Watch_remove
)
| Watch2 { arr; ev } ->
(* find another watch. If none is present, keep the
current ones and call [f]. *)
if TVar.equal arr.(0) updated_var then
(* ensure that if there is only one watch, it's the first *)
Util.swap_array arr 0 1
else
assert (TVar.equal arr.(1) updated_var);
let i, all_set1 = find_watch_ self.vst arr ~idx0:1 in
if all_set1 then (
if TVar.has_value self.vst arr.(0) then
f ~unit:false ev
else
f ~unit:true ev;
(* just keep current watch *)
Watch_keep
) else (
(* use [i] as the second watch *)
assert (i > 1);
Util.swap_array arr i 1;
set_watch self arr.(1) h;
Watch_remove
)
let update (self : t) (v : TVar.t) ~(f : cb_) : unit =
let vec = Handle_v_map.get self.by_var v in
let i = ref 0 in
while !i < Veci.size vec do
let handle = Veci.get vec !i in
let w = Vec.get self.watches handle in
match update1 self handle w ~updated_var:v ~f with
| Watch_keep -> incr i
| Watch_remove -> Veci.fast_remove vec !i
done
end

View file

@ -0,0 +1,31 @@
(** Watch schemes *)
module Make (Ev : sig
type t
end) : sig
type t
type handle
val create : TVar.store -> t
(** New set of watchers *)
val watch2 :
t -> TVar.t array -> Ev.t -> f:(unit:bool -> Ev.t -> unit) -> handle
(** 2-watch scheme on these variables. *)
val watch1 :
t -> TVar.t array -> Ev.t -> f:(unit:bool -> Ev.t -> unit) -> handle
(** 1-watch scheme on these variables. *)
val kill : t -> handle -> unit
(** Disable watch *)
val gc : t -> unit
(** Reclaim slots that have been killed *)
val update : t -> TVar.t -> f:(unit:bool -> Ev.t -> unit) -> unit
(** [update watches v ~f] updates watchers that contain [v],
and calls [f ~unit ev] for each event whose watch just saturated.
[unit] is true if the watch is a 2-watch that became unit; [false] in
any other case (including a fully saturated 2-watch) *)
end

View file

@ -1,16 +0,0 @@
(* find a term in [w] that is not assigned, or otherwise,
the one with highest level
@return index of term to watch, and [true] if all are assigned *)
let rec find_watch_ tst w i highest : int * bool =
if i = Array.length w then
highest, true
else if TVar.has_value tst w.(i) then (
let highest =
if TVar.level tst w.(i) > TVar.level tst w.(highest) then
i
else
highest
in
find_watch_ tst w (i + 1) highest
) else
i, false