diff --git a/src/cdsat/TVar.ml b/src/cdsat/TVar.ml index 858a2307..7cafa659 100644 --- a/src/cdsat/TVar.ml +++ b/src/cdsat/TVar.ml @@ -15,6 +15,7 @@ type store = { level: Veci.t; value: Value.t option Vec.t; reason: reason Vec.t; + watches: t Vec.t Vec.t; has_value: Bitvec.t; new_vars: Vec_of.t; } @@ -26,15 +27,25 @@ and reason = (* create a new variable *) let new_var_ (self : store) ~term:(term_for_v : Term.t) () : t = let v : t = Vec.size self.term in - let { tst = _; of_term = _; term; level; value; reason; has_value; new_vars } - = + let { + tst = _; + of_term = _; + term; + level; + value; + watches; + reason; + has_value; + new_vars; + } = self in Vec.push term term_for_v; Veci.push level (-1); Vec.push value None; - Vec.push reason Decide; (* fake *) + Vec.push reason Decide; + Vec.push watches (Vec.create ()); Bitvec.ensure_size has_value (v + 1); Bitvec.set has_value v false; Vec_of.push new_vars v; @@ -50,11 +61,24 @@ let of_term (self : store) (t : Term.t) : t = will allow the variable to be properly defined in one theory? *) v -let has_value (self : store) (v : t) : bool = Bitvec.get self.has_value v -let level (self : store) (v : t) : int = Veci.get self.level v -let value (self : store) (v : t) : _ option = Vec.get self.value v -let term (self : store) (v : t) : Term.t = Vec.get self.term v -let reason (self : store) (v : t) : reason = Vec.get self.reason v +let[@inline] has_value (self : store) (v : t) : bool = + Bitvec.get self.has_value v + +let[@inline] level (self : store) (v : t) : int = Veci.get self.level v +let[@inline] value (self : store) (v : t) : _ option = Vec.get self.value v + +let[@inline] set_value (self : store) (v : t) value : unit = + Vec.set self.value v (Some value) + +let[@inline] unset_value (self : store) (v : t) : unit = + Vec.set self.value v None + +let[@inline] term (self : store) (v : t) : Term.t = Vec.get self.term v +let[@inline] reason (self : store) (v : t) : reason = Vec.get self.reason v +let[@inline] watchers (self : store) (v : t) : t Vec.t = Vec.get self.watches v + +let[@inline] add_watcher (self : store) (v : t) ~watcher : unit = + Vec.push (watchers self v) watcher let pop_new_var self : _ option = if Vec_of.is_empty self.new_vars then @@ -99,6 +123,7 @@ module Store = struct reason = Vec.create (); term = Vec.create (); level = Veci.create (); + watches = Vec.create (); value = Vec.create (); has_value = Bitvec.create (); new_vars = Vec_of.create (); diff --git a/src/cdsat/TVar.mli b/src/cdsat/TVar.mli index 13999160..64b87738 100644 --- a/src/cdsat/TVar.mli +++ b/src/cdsat/TVar.mli @@ -48,5 +48,15 @@ val level : store -> t -> int val value : store -> t -> Value.t option (** Value in the current assignment *) +val set_value : store -> t -> Value.t -> unit +val unset_value : store -> t -> unit + +val watchers : store -> t -> t Vec.t +(** [watchers store t] is a vector of other variables watching [t], + generally updated via {!Watch1} and {!Watch2}. + These other variables are notified when [t] is assigned. *) + +val add_watcher : store -> t -> watcher:t -> unit + val pop_new_var : store -> t option (** Pop a new variable if any, or return [None] *) diff --git a/src/cdsat/watch1.ml b/src/cdsat/watch1.ml new file mode 100644 index 00000000..060bf898 --- /dev/null +++ b/src/cdsat/watch1.ml @@ -0,0 +1,32 @@ +open Watch_utils_ + +type t = TVar.t array + +let dummy = [||] +let make = Array.of_list +let[@inline] make_a a : t = a +let[@inline] iter w k = if Array.length w > 0 then k w.(0) + +let init tst w t ~on_all_set : unit = + let i, all_set = find_watch_ tst w 0 0 in + (* put watch first *) + Util.swap_array w i 0; + TVar.add_watcher tst w.(0) ~watcher:t; + if all_set then on_all_set (); + () + +let update tst w t ~watch ~on_all_set : Watch_res.t = + (* find another watch. If none is present, keep the + current one and call [on_all_set]. *) + assert (w.(0) == watch); + let i, all_set = find_watch_ tst w 0 0 in + if all_set then ( + on_all_set (); + Watch_res.Watch_keep (* just keep current watch *) + ) else ( + (* use [i] as the watch *) + assert (i > 0); + Util.swap_array w i 0; + TVar.add_watcher tst w.(0) ~watcher:t; + Watch_res.Watch_remove + ) diff --git a/src/cdsat/watch1.mli b/src/cdsat/watch1.mli new file mode 100644 index 00000000..e377f613 --- /dev/null +++ b/src/cdsat/watch1.mli @@ -0,0 +1,32 @@ +(** 1-Watch Scheme *) + +type t + +val dummy : t +val make : TVar.t list -> t + +val make_a : TVar.t array -> t +(** owns the array *) + +val iter : t -> TVar.t Iter.t +(** current watch(es) *) + +val init : TVar.store -> t -> TVar.t -> on_all_set:(unit -> unit) -> unit +(** [init tstore w t ~on_all_set] initializes [w] (the watchlist) for + var [t], by finding an unassigned TVar.t in the watchlist and + registering [t] to it. + If all vars are set, then it watches the one with the highest level + and call [on_all_set ()] + *) + +val update : + TVar.store -> + t -> + TVar.t -> + watch:TVar.t -> + on_all_set:(unit -> unit) -> + Watch_res.t +(** [update tstore w t ~watch ~on_all_set] updates [w] after [watch] + has been assigned. It looks for another TVar.t in [w] for [t] to watch. + If all vars are set, then it calls [on_all_set ()] +*) diff --git a/src/cdsat/watch2.ml b/src/cdsat/watch2.ml new file mode 100644 index 00000000..31e44514 --- /dev/null +++ b/src/cdsat/watch2.ml @@ -0,0 +1,56 @@ +open Watch_utils_ + +type t = TVar.t array + +let dummy = [||] +let make = Array.of_list +let[@inline] make_a a : t = a + +let[@inline] iter w k = + if Array.length w > 0 then ( + k w.(0); + if Array.length w > 1 then k w.(1) + ) + +let init tst w t ~on_unit ~on_all_set : unit = + let i0, all_set0 = find_watch_ tst w 0 0 in + Util.swap_array w i0 0; + let i1, all_set1 = find_watch_ tst w 1 0 in + Util.swap_array w i1 1; + TVar.add_watcher tst w.(0) ~watcher:t; + TVar.add_watcher tst w.(1) ~watcher:t; + assert ( + if all_set0 then + all_set1 + else + true); + if all_set0 then + on_all_set () + else if all_set1 then ( + assert (not (TVar.has_value tst w.(0))); + on_unit w.(0) + ); + () + +let update tst w t ~watch ~on_unit ~on_all_set : Watch_res.t = + (* find another watch. If none is present, keep the + current ones and call [on_unit] or [on_all_set]. *) + if w.(0) == watch then + (* ensure that if there is only one watch, it's the first *) + Util.swap_array w 0 1 + else + assert (w.(1) == watch); + let i, all_set1 = find_watch_ tst w 1 0 in + if all_set1 then ( + if TVar.has_value tst w.(0) then + on_all_set () + else + on_unit w.(0); + Watch_res.Watch_keep (* just keep current watch *) + ) else ( + (* use [i] as the second watch *) + assert (i > 1); + Util.swap_array w i 1; + TVar.add_watcher tst w.(1) ~watcher:t; + Watch_res.Watch_remove + ) diff --git a/src/cdsat/watch2.mli b/src/cdsat/watch2.mli new file mode 100644 index 00000000..df21a89e --- /dev/null +++ b/src/cdsat/watch2.mli @@ -0,0 +1,41 @@ +(** 2-Watch Scheme *) + +type t + +val dummy : t +val make : TVar.t list -> t + +val make_a : TVar.t array -> t +(** owns the array *) + +val iter : t -> TVar.t Iter.t +(** current watch(es) *) + +val init : + TVar.store -> + t -> + TVar.t -> + on_unit:(TVar.t -> unit) -> + on_all_set:(unit -> unit) -> + unit +(** [init tstore w t ~on_all_set] initializes [w] (the watchlist) for + var [t], by finding an unassigned var in the watchlist and + registering [t] to it. + If exactly one TVar.t [u] is not set, then it calls [on_unit u]. + If all vars are set, then it watches the one with the highest level + and call [on_all_set ()] + *) + +val update : + TVar.store -> + t -> + TVar.t -> + watch:TVar.t -> + on_unit:(TVar.t -> unit) -> + on_all_set:(unit -> unit) -> + Watch_res.t +(** [update tstore w t ~watch ~on_all_set] updates [w] after [watch] + has been assigned. It looks for another var in [w] for [t] to watch. + If exactly one var [u] is not set, then it calls [on_unit u]. + If all vars are set, then it calls [on_all_set ()] +*) diff --git a/src/cdsat/watch_res.ml b/src/cdsat/watch_res.ml new file mode 100644 index 00000000..f14eeee8 --- /dev/null +++ b/src/cdsat/watch_res.ml @@ -0,0 +1,3 @@ +type t = + | Watch_keep (** Keep the watch *) + | Watch_remove (** Remove the watch *) diff --git a/src/cdsat/watch_utils_.ml b/src/cdsat/watch_utils_.ml new file mode 100644 index 00000000..fd3ea00d --- /dev/null +++ b/src/cdsat/watch_utils_.ml @@ -0,0 +1,16 @@ +(* find a term in [w] that is not assigned, or otherwise, + the one with highest level + @return index of term to watch, and [true] if all are assigned *) +let rec find_watch_ tst w i highest : int * bool = + if i = Array.length w then + highest, true + else if TVar.has_value tst w.(i) then ( + let highest = + if TVar.level tst w.(i) > TVar.level tst w.(highest) then + i + else + highest + in + find_watch_ tst w (i + 1) highest + ) else + i, false