From 6f1abedb449c203927d2c5070512b7463c5fb5cb Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Tue, 1 Nov 2022 22:22:14 -0400 Subject: [PATCH] feat(cdsat): embryo of plugins for bool and UF --- src/cdsat/TLit.ml | 13 +++++ src/cdsat/TLit.mli | 10 ++++ src/cdsat/TVar.ml | 11 +++-- src/cdsat/TVar.mli | 4 +- src/cdsat/core.ml | 78 +++++++++++++++++++----------- src/cdsat/core.mli | 36 +++++++------- src/cdsat/plugin_bool.ml | 70 +++++++++++++++++++++++++++ src/cdsat/plugin_bool.mli | 15 ++++++ src/cdsat/plugin_uninterpreted.ml | 68 ++++++++++++++++++++++++++ src/cdsat/plugin_uninterpreted.mli | 15 ++++++ src/cdsat/reason.ml | 11 +++-- src/cdsat/reason.mli | 5 +- src/cdsat/sidekick_cdsat.ml | 6 +++ src/cdsat/solver.ml | 42 +++++++++++++--- src/cdsat/solver.mli | 8 ++- src/cdsat/watch1.ml | 52 +++++++++++--------- src/cdsat/watch1.mli | 21 ++++---- src/cdsat/watch2.ml | 77 +++++++++++++++-------------- src/cdsat/watch2.mli | 23 +++++---- src/main/main.ml | 7 ++- 20 files changed, 420 insertions(+), 152 deletions(-) create mode 100644 src/cdsat/TLit.ml create mode 100644 src/cdsat/TLit.mli create mode 100644 src/cdsat/plugin_bool.ml create mode 100644 src/cdsat/plugin_bool.mli create mode 100644 src/cdsat/plugin_uninterpreted.ml create mode 100644 src/cdsat/plugin_uninterpreted.mli diff --git a/src/cdsat/TLit.ml b/src/cdsat/TLit.ml new file mode 100644 index 00000000..99fd9750 --- /dev/null +++ b/src/cdsat/TLit.ml @@ -0,0 +1,13 @@ +type t = { var: TVar.t; sign: bool } + +let[@inline] make sign var : t = { sign; var } +let[@inline] neg self = { self with sign = not self.sign } +let[@inline] abs self = { self with sign = true } +let[@inline] sign self = self.sign +let[@inline] var self = self.var + +let pp vst out (self : t) = + if self.sign then + TVar.pp vst out self.var + else + Fmt.fprintf out "(@[not@ %a@])" (TVar.pp vst) self.var diff --git a/src/cdsat/TLit.mli b/src/cdsat/TLit.mli new file mode 100644 index 00000000..d85d0865 --- /dev/null +++ b/src/cdsat/TLit.mli @@ -0,0 +1,10 @@ +(** Literal for {!TVar.t} *) + +type t = { var: TVar.t; sign: bool } + +val make : bool -> TVar.t -> t +val neg : t -> t +val abs : t -> t +val sign : t -> bool +val var : t -> TVar.t +val pp : TVar.store -> t Fmt.printer diff --git a/src/cdsat/TVar.ml b/src/cdsat/TVar.ml index 97489092..0d62b2e0 100644 --- a/src/cdsat/TVar.ml +++ b/src/cdsat/TVar.ml @@ -11,10 +11,11 @@ module Vec_of = Veci next [new_var_] allocation *) type reason = - | Decide + | Decide of { level: int } | Propagate of { level: int; hyps: Vec_of.t; proof: Sidekick_proof.step_id } let dummy_level_ = -1 +let dummy_reason_ : reason = Decide { level = dummy_level_ } type store = { tst: Term.store; @@ -50,7 +51,7 @@ let new_var_ (self : store) ~term:(term_for_v : Term.t) ~theory_view () : t = Veci.push level dummy_level_; Vec.push value None; (* fake *) - Vec.push reason Decide; + Vec.push reason dummy_reason_; Vec.push watches (Vec.create ()); Vec.push theory_views theory_view; Bitvec.ensure_size has_value (v + 1); @@ -83,7 +84,7 @@ let[@inline] add_watcher (self : store) (v : t) ~watcher : unit = let assign (self : store) (v : t) ~value ~level ~reason : unit = Log.debugf 50 (fun k -> - k "(@[cdsat.assign[lvl=%d]@ %a@ :val %a@])" level + k "(@[cdsat.tvar.assign[lvl=%d]@ %a@ :val %a@])" level (Term.pp_limit ~max_depth:5 ~max_nodes:30) (term self v) Term.pp value); assert (level >= 0); @@ -94,7 +95,7 @@ let assign (self : store) (v : t) ~value ~level ~reason : unit = let unassign (self : store) (v : t) : unit = Vec.set self.value v None; Veci.set self.level v dummy_level_; - Vec.set self.reason v Decide + Vec.set self.reason v dummy_reason_ let pop_new_var self : _ option = if Vec_of.is_empty self.new_vars then @@ -105,6 +106,8 @@ let pop_new_var self : _ option = module Store = struct type t = store + let tst self = self.tst + let create tst : t = { tst; diff --git a/src/cdsat/TVar.mli b/src/cdsat/TVar.mli index c1b9ce32..b1a8dc5b 100644 --- a/src/cdsat/TVar.mli +++ b/src/cdsat/TVar.mli @@ -14,6 +14,8 @@ type theory_view = .. module Store : sig type t + val tst : t -> Term.store + val create : Term.store -> t (** Create a new store *) end @@ -24,7 +26,7 @@ module Vec_of : Vec_sig.S with type elt := t type store = Store.t type reason = - | Decide + | Decide of { level: int } | Propagate of { level: int; hyps: Vec_of.t; proof: Sidekick_proof.step_id } val get_of_term : store -> Term.t -> t option diff --git a/src/cdsat/core.ml b/src/cdsat/core.ml index 36b00fb8..83ad0da8 100644 --- a/src/cdsat/core.ml +++ b/src/cdsat/core.ml @@ -8,40 +8,32 @@ module type ARG = sig (** build a disjunction *) end -module Plugin_action = struct - type t = { propagate: TVar.t -> Value.t -> Reason.t -> unit } - - let propagate (self : t) var v reas : unit = self.propagate var v reas -end - -(** Core plugin *) -module Plugin = struct - type t = { - name: string; - push_level: unit -> unit; - pop_levels: int -> unit; - decide: TVar.t -> Value.t option; - propagate: Plugin_action.t -> TVar.t -> Value.t -> unit; - term_to_var_hooks: Term_to_var.hook list; - } - - let make ~name ~push_level ~pop_levels ~decide ~propagate ~term_to_var_hooks - () : t = - { name; push_level; pop_levels; decide; propagate; term_to_var_hooks } -end - type t = { tst: Term.store; vst: TVar.store; arg: (module ARG); stats: Stat.t; trail: Trail.t; - plugins: Plugin.t Vec.t; + plugins: plugin Vec.t; term_to_var: Term_to_var.t; mutable last_res: Check_res.t option; proof_tracer: Proof.Tracer.t; } +and plugin_action = t + +and plugin = + | P : { + st: 'st; + name: string; + push_level: 'st -> unit; + pop_levels: 'st -> int -> unit; + decide: 'st -> TVar.t -> Value.t option; + propagate: 'st -> plugin_action -> TVar.t -> Value.t -> unit; + term_to_var_hooks: 'st -> Term_to_var.hook list; + } + -> plugin + let create ?(stats = Stat.create ()) ~arg tst vst ~proof_tracer () : t = { tst; @@ -61,6 +53,21 @@ let[@inline] tst self = self.tst let[@inline] vst self = self.vst let[@inline] last_res self = self.last_res +(* plugins *) + +module Plugin = struct + type t = plugin + type builder = TVar.store -> t + + let[@inline] name (P p) = p.name + + let make_builder ~name ~create ~push_level ~pop_levels ~decide ~propagate + ~term_to_var_hooks () : builder = + fun vst -> + let st = create vst in + P { name; st; push_level; pop_levels; decide; propagate; term_to_var_hooks } +end + (* backtracking *) let n_levels (self : t) : int = Trail.n_levels self.trail @@ -68,14 +75,14 @@ let n_levels (self : t) : int = Trail.n_levels self.trail let push_level (self : t) : unit = Log.debugf 50 (fun k -> k "(@[cdsat.core.push-level@])"); Trail.push_level self.trail; - Vec.iter self.plugins ~f:(fun (p : Plugin.t) -> p.push_level ()); + Vec.iter self.plugins ~f:(fun (P p) -> p.push_level p.st); () let pop_levels (self : t) n : unit = Log.debugf 50 (fun k -> k "(@[cdsat.core.pop-levels %d@])" n); if n > 0 then self.last_res <- None; Trail.pop_levels self.trail n ~f:(fun v -> TVar.unassign self.vst v); - Vec.iter self.plugins ~f:(fun (p : Plugin.t) -> p.pop_levels n); + Vec.iter self.plugins ~f:(fun (P p) -> p.pop_levels p.st n); () (* term to var *) @@ -89,9 +96,10 @@ let add_term_to_var_hook self h = Term_to_var.add_hook self.term_to_var h (* plugins *) -let add_plugin self p = - Vec.push self.plugins p; - List.iter (add_term_to_var_hook self) p.term_to_var_hooks +let add_plugin self (pb : Plugin.builder) : unit = + let (P p as plugin) = pb self.vst in + Vec.push self.plugins plugin; + List.iter (add_term_to_var_hook self) (p.term_to_var_hooks p.st) (* solving *) @@ -100,7 +108,8 @@ let add_ty (_self : t) ~ty:_ : unit = () let assign (self : t) (v : TVar.t) ~(value : Value.t) ~level:v_level ~reason : unit = Log.debugf 50 (fun k -> - k "(@[cdsat.core.assign@ %a@ <- %a@])" (TVar.pp self.vst) v Value.pp value); + k "(@[cdsat.core.assign@ `%a`@ @[<- %a@]@ :reason %a@])" + (TVar.pp self.vst) v Value.pp value Reason.pp reason); self.last_res <- None; match TVar.value self.vst v with | None -> @@ -118,3 +127,14 @@ let solve ~on_exit ~on_progress ~should_stop ~assumptions (self : t) : (* TODO: outer loop (propagate; decide)* *) (* TODO: propagation loop, involving plugins *) assert false + +(* plugin actions *) + +module Plugin_action = struct + type t = plugin_action + + let[@inline] propagate (self : t) var value reason : unit = + assign self var ~value ~level:(Reason.level reason) ~reason + + let term_to_var = term_to_var +end diff --git a/src/cdsat/core.mli b/src/cdsat/core.mli index 747dd85c..3ce21cae 100644 --- a/src/cdsat/core.mli +++ b/src/cdsat/core.mli @@ -15,28 +15,26 @@ module Plugin_action : sig type t val propagate : t -> TVar.t -> Value.t -> Reason.t -> unit + val term_to_var : t -> Term.t -> TVar.t end (** Core plugin *) module Plugin : sig - type t = private { - name: string; - push_level: unit -> unit; - pop_levels: int -> unit; - decide: TVar.t -> Value.t option; - propagate: Plugin_action.t -> TVar.t -> Value.t -> unit; - term_to_var_hooks: Term_to_var.hook list; - } + type t + type builder - val make : + val name : t -> string + + val make_builder : name:string -> - push_level:(unit -> unit) -> - pop_levels:(int -> unit) -> - decide:(TVar.t -> Value.t option) -> - propagate:(Plugin_action.t -> TVar.t -> Value.t -> unit) -> - term_to_var_hooks:Term_to_var.hook list -> + create:(TVar.store -> 'st) -> + push_level:('st -> unit) -> + pop_levels:('st -> int -> unit) -> + decide:('st -> TVar.t -> Value.t option) -> + propagate:('st -> Plugin_action.t -> TVar.t -> Value.t -> unit) -> + term_to_var_hooks:('st -> Term_to_var.hook list) -> unit -> - t + builder end (** {2 Basics} *) @@ -56,7 +54,7 @@ val create : val tst : t -> Term.store val vst : t -> TVar.store val trail : t -> Trail.t -val add_plugin : t -> Plugin.t -> unit +val add_plugin : t -> Plugin.builder -> unit val iter_plugins : t -> Plugin.t Iter.t val last_res : t -> Check_res.t option @@ -78,9 +76,9 @@ val assign : t -> TVar.t -> value:Value.t -> level:int -> reason:Reason.t -> unit val solve : - on_exit:(unit -> unit) -> + on_exit:(unit -> unit) list -> on_progress:(unit -> unit) -> - should_stop:(unit -> bool) -> - assumptions:Term.t list -> + should_stop:(int -> bool) -> + assumptions:Lit.t list -> t -> Check_res.t diff --git a/src/cdsat/plugin_bool.ml b/src/cdsat/plugin_bool.ml new file mode 100644 index 00000000..d80a8975 --- /dev/null +++ b/src/cdsat/plugin_bool.ml @@ -0,0 +1,70 @@ +type 'a view = 'a Bool_view.t + +(** Argument to the plugin *) +module type ARG = sig + val view : Term.t -> Term.t view + val or_l : Term.store -> Term.t list -> Term.t + val and_l : Term.store -> Term.t list -> Term.t +end + +(* our custom view of terms *) +type TVar.theory_view += + | T_bool of bool + | T_and of TLit.t array + | T_or of TLit.t array + +(* our internal state *) +type t = { arg: (module ARG); tst: Term.store; vst: TVar.store } + +let push_level (self : t) = () +let pop_levels (self : t) n = () + +let decide (self : t) (v : TVar.t) : Value.t option = + match TVar.theory_view self.vst v with + | T_bool b -> + (* FIXME: this should be propagated earlier, shouldn't it? *) + Some (Term.bool_val self.tst b) + | T_and _ | T_or _ -> + (* TODO: phase saving? or is it done directly in the core? *) + Some (Term.true_ self.tst) + | _ -> None + +let propagate (self : t) (act : Core.Plugin_action.t) (v : TVar.t) + (value : Value.t) : unit = + () +(* TODO: BCP *) + +let term_to_var_hooks (self : t) : _ list = + let (module A) = self.arg in + + let rec to_tlit t2v (t : Term.t) : TLit.t = + match A.view t with + | Bool_view.B_not u -> + let lit = to_tlit t2v u in + TLit.neg lit + | _ -> + let v = Term_to_var.convert t2v t in + TLit.make true v + in + + (* main hook to convert formulas *) + let h t2v (t : Term.t) : _ option = + match A.view t with + | Bool_view.B_bool b -> Some (T_bool b) + | Bool_view.B_and l -> + let lits = Util.array_of_list_map (to_tlit t2v) l in + Some (T_and lits) + | Bool_view.B_or l -> + let lits = Util.array_of_list_map (to_tlit t2v) l in + Some (T_or lits) + | _ -> None + in + [ h ] + +let builder ((module A : ARG) as arg) : Core.Plugin.builder = + let create vst : t = + let tst = TVar.Store.tst vst in + { arg; vst; tst } + in + Core.Plugin.make_builder ~name:"bool" ~create ~push_level ~pop_levels ~decide + ~propagate ~term_to_var_hooks () diff --git a/src/cdsat/plugin_bool.mli b/src/cdsat/plugin_bool.mli new file mode 100644 index 00000000..b837566f --- /dev/null +++ b/src/cdsat/plugin_bool.mli @@ -0,0 +1,15 @@ +(** Plugin for boolean formulas *) + +open Sidekick_core + +type 'a view = 'a Bool_view.t + +(** Argument to the plugin *) +module type ARG = sig + val view : Term.t -> Term.t view + val or_l : Term.store -> Term.t list -> Term.t + val and_l : Term.store -> Term.t list -> Term.t +end + +val builder : (module ARG) -> Core.Plugin.builder +(** Create a new plugin *) diff --git a/src/cdsat/plugin_uninterpreted.ml b/src/cdsat/plugin_uninterpreted.ml new file mode 100644 index 00000000..f2ebf660 --- /dev/null +++ b/src/cdsat/plugin_uninterpreted.ml @@ -0,0 +1,68 @@ +(** Plugin for uninterpreted symbols *) + +open Sidekick_core + +module type ARG = sig + val is_unin_function : Term.t -> bool +end + +(* store data for each unin function application *) +type TVar.theory_view += + | Unin_const of Term.t + | Unin_fun of { f: Term.t; args: TVar.t array } + +(* congruence table *) +module Cong_tbl = Backtrackable_tbl.Make (struct + type t = { f: Term.t; args: Value.t array } + + let equal a b = Term.equal a.f b.f && CCArray.equal Value.equal a.args b.args + let hash a = CCHash.(combine2 (Term.hash a.f) (array Value.hash a.args)) +end) + +(* an entry [f(values) -> value], used to detect congruence rule *) +type cong_entry = { v_args: TVar.t array; res: Value.t; v_res: TVar.t } + +type t = { + arg: (module ARG); + vst: TVar.store; + cong_table: cong_entry Cong_tbl.t; +} + +let create arg vst : t = { arg; vst; cong_table = Cong_tbl.create ~size:256 () } +let push_level self = Cong_tbl.push_level self.cong_table +let pop_levels self n = Cong_tbl.pop_levels self.cong_table n + +(* let other theories decide, depending on the type *) +let decide _ _ = None + +let propagate (self : t) _act (v : TVar.t) (value : Value.t) = + match TVar.theory_view self.vst v with + | Unin_const _ -> () + | Unin_fun { f = _; args } -> + (* TODO: update congruence table *) + Log.debugf 1 (fun k -> k "FIXME: congruence rule"); + () + | _ -> () + +(* handle new terms *) +let term_to_var_hooks (self : t) : _ list = + let (module A) = self.arg in + let h t2v (t : Term.t) : _ option = + let f, args = Term.unfold_app t in + if A.is_unin_function f then ( + (* convert arguments to vars *) + let args = Util.array_of_list_map (Term_to_var.convert t2v) args in + if Array.length args = 0 then + Some (Unin_const t) + else + Some (Unin_fun { f; args }) + ) else + None + in + [ h ] + +(* TODO: congruence rules *) + +let builder ((module A : ARG) as arg) : Core.Plugin.builder = + Core.Plugin.make_builder ~name:"uf" ~create:(create arg) ~push_level + ~pop_levels ~decide ~propagate ~term_to_var_hooks () diff --git a/src/cdsat/plugin_uninterpreted.mli b/src/cdsat/plugin_uninterpreted.mli new file mode 100644 index 00000000..8ffd2f01 --- /dev/null +++ b/src/cdsat/plugin_uninterpreted.mli @@ -0,0 +1,15 @@ +(** Plugin for uninterpreted symbols *) + +open Sidekick_core + +(** Argument to the plugin *) +module type ARG = sig + val is_unin_function : Term.t -> bool + (** [is_unin_function t] should be true iff [t] is a function symbol + or constant symbol that is uninterpreted + (possibly applied to {b type} arguments in the case of a polymorphic + function/constant). *) +end + +val builder : (module ARG) -> Core.Plugin.builder +(** Create a new plugin *) diff --git a/src/cdsat/reason.ml b/src/cdsat/reason.ml index 2bf11b5a..0a75fc76 100644 --- a/src/cdsat/reason.ml +++ b/src/cdsat/reason.ml @@ -1,5 +1,5 @@ type t = TVar.reason = - | Decide + | Decide of { level: int } | Propagate of { level: int; hyps: TVar.Vec_of.t; @@ -8,14 +8,19 @@ type t = TVar.reason = let pp out (self : t) : unit = match self with - | Decide -> Fmt.string out "decide" + | Decide { level } -> Fmt.fprintf out "decide[lvl=%d]" level | Propagate { level; hyps; proof = _ } -> Fmt.fprintf out "(@[propagate[lvl=%d]@ :n-hyps %d@])" level (TVar.Vec_of.size hyps) -let decide : t = Decide +let[@inline] decide level : t = Decide { level } let[@inline] propagate_ level v proof : t = Propagate { level; hyps = v; proof } +let[@inline] level self = + match self with + | Decide d -> d.level + | Propagate p -> p.level + let propagate_v store v proof : t = let level = TVar.Vec_of.fold_left (fun l v -> max l (TVar.level store v)) 0 v diff --git a/src/cdsat/reason.mli b/src/cdsat/reason.mli index 53c8a123..2e0761ae 100644 --- a/src/cdsat/reason.mli +++ b/src/cdsat/reason.mli @@ -2,7 +2,7 @@ (** Reason for assignment *) type t = TVar.reason = - | Decide + | Decide of { level: int } | Propagate of { level: int; hyps: TVar.Vec_of.t; @@ -11,6 +11,7 @@ type t = TVar.reason = include Sidekick_sigs.PRINT with type t := t -val decide : t +val decide : int -> t val propagate_v : TVar.store -> TVar.Vec_of.t -> Sidekick_proof.step_id -> t val propagate_l : TVar.store -> TVar.t list -> Sidekick_proof.step_id -> t +val level : t -> int diff --git a/src/cdsat/sidekick_cdsat.ml b/src/cdsat/sidekick_cdsat.ml index d0c5a9c4..4245b995 100644 --- a/src/cdsat/sidekick_cdsat.ml +++ b/src/cdsat/sidekick_cdsat.ml @@ -2,8 +2,14 @@ module Trail = Trail module TVar = TVar +module TLit = TLit module Reason = Reason module Value = Value module Core = Core module Solver = Solver module Term_to_var = Term_to_var + +(** {2 Builtin plugins} *) + +module Plugin_bool = Plugin_bool +module Plugin_uninterpreted = Plugin_uninterpreted diff --git a/src/cdsat/solver.ml b/src/cdsat/solver.ml index d5105e7d..554db36a 100644 --- a/src/cdsat/solver.ml +++ b/src/cdsat/solver.ml @@ -5,19 +5,31 @@ module Check_res = Asolver.Check_res module Plugin_action = Core.Plugin_action module Plugin = Core.Plugin -module type ARG = Core.ARG +module type ARG = sig + module Core : Core.ARG + module Bool : Plugin_bool.ARG + module UF : Plugin_uninterpreted.ARG +end type t = { tst: Term.store; vst: TVar.store; core: Core.t; stats: Stat.t; + arg: (module ARG); proof_tracer: Proof.Tracer.t; } -let create ?(stats = Stat.create ()) ~arg tst vst ~proof_tracer () : t = - let core = Core.create ~stats ~arg tst vst ~proof_tracer () in - { tst; vst; core; stats; proof_tracer } +let create ?(stats = Stat.create ()) ~(arg : (module ARG)) tst vst ~proof_tracer + () : t = + let (module A) = arg in + let core = + Core.create ~stats ~arg:(module A.Core : Core.ARG) tst vst ~proof_tracer () + in + Core.add_plugin core (Plugin_bool.builder (module A.Bool : Plugin_bool.ARG)); + Core.add_plugin core + (Plugin_uninterpreted.builder (module A.UF : Plugin_uninterpreted.ARG)); + { tst; vst; arg; core; stats; proof_tracer } let[@inline] core self = self.core let add_plugin self p = Core.add_plugin self.core p @@ -56,13 +68,27 @@ let assert_term self t : unit = in assert_term_ self t pr -let assert_clause (self : t) lits p : unit = assert false (* TODO *) +let assert_clause (self : t) (lits : Lit.t array) p : unit = + let (module A) = self.arg in + (* turn literals into a or-term *) + let args = + Util.array_to_list_map + (fun lit -> + let t = Lit.term lit in + if Lit.sign lit then + t + else + Term.not self.tst t) + lits + in + let t = A.Core.or_l self.tst args in + assert_term_ self t p let pp_stats out self = Stat.pp out self.stats -let solve ?on_exit ?on_progress ?should_stop ~assumptions (self : t) : - Check_res.t = - assert false +let solve ?(on_exit = []) ?(on_progress = ignore) + ?(should_stop = fun _ -> false) ~assumptions (self : t) : Check_res.t = + Core.solve self.core ~on_exit ~on_progress ~should_stop ~assumptions (* asolver interface *) diff --git a/src/cdsat/solver.mli b/src/cdsat/solver.mli index 8b50e8d1..6b56ff6b 100644 --- a/src/cdsat/solver.mli +++ b/src/cdsat/solver.mli @@ -9,7 +9,11 @@ open Sidekick_proof module Plugin_action = Core.Plugin_action module Plugin = Core.Plugin -module type ARG = Core.ARG +module type ARG = sig + module Core : Core.ARG + module Bool : Plugin_bool.ARG + module UF : Plugin_uninterpreted.ARG +end (** {2 Basics} *) @@ -28,7 +32,7 @@ val create : val tst : t -> Term.store val vst : t -> TVar.store val core : t -> Core.t -val add_plugin : t -> Plugin.t -> unit +val add_plugin : t -> Plugin.builder -> unit val iter_plugins : t -> Plugin.t Iter.t (** {2 Solving} *) diff --git a/src/cdsat/watch1.ml b/src/cdsat/watch1.ml index 060bf898..1a9b64e5 100644 --- a/src/cdsat/watch1.ml +++ b/src/cdsat/watch1.ml @@ -1,32 +1,38 @@ open Watch_utils_ -type t = TVar.t array +type t = { vst: TVar.store; arr: TVar.t array; mutable alive: bool } -let dummy = [||] -let make = Array.of_list -let[@inline] make_a a : t = a -let[@inline] iter w k = if Array.length w > 0 then k w.(0) +let make vst l = { alive = true; vst; arr = Array.of_list l } +let[@inline] make_a vst arr : t = { alive = true; vst; arr } +let[@inline] alive self = self.alive +let[@inline] kill self = self.alive <- false -let init tst w t ~on_all_set : unit = - let i, all_set = find_watch_ tst w 0 0 in +let[@inline] iter (self : t) k = + if Array.length self.arr > 0 then k self.arr.(0) + +let init (self : t) t ~on_all_set : unit = + let i, all_set = find_watch_ self.vst self.arr 0 0 in (* put watch first *) - Util.swap_array w i 0; - TVar.add_watcher tst w.(0) ~watcher:t; + Util.swap_array self.arr i 0; + TVar.add_watcher self.vst self.arr.(0) ~watcher:t; if all_set then on_all_set (); () -let update tst w t ~watch ~on_all_set : Watch_res.t = - (* find another watch. If none is present, keep the - current one and call [on_all_set]. *) - assert (w.(0) == watch); - let i, all_set = find_watch_ tst w 0 0 in - if all_set then ( - on_all_set (); - Watch_res.Watch_keep (* just keep current watch *) - ) else ( - (* use [i] as the watch *) - assert (i > 0); - Util.swap_array w i 0; - TVar.add_watcher tst w.(0) ~watcher:t; +let update (self : t) t ~watch ~on_all_set : Watch_res.t = + if self.alive then ( + (* find another watch. If none is present, keep the + current one and call [on_all_set]. *) + assert (self.arr.(0) == watch); + let i, all_set = find_watch_ self.vst self.arr 0 0 in + if all_set then ( + on_all_set (); + Watch_res.Watch_keep (* just keep current watch *) + ) else ( + (* use [i] as the watch *) + assert (i > 0); + Util.swap_array self.arr i 0; + TVar.add_watcher self.vst self.arr.(0) ~watcher:t; + Watch_res.Watch_remove + ) + ) else Watch_res.Watch_remove - ) diff --git a/src/cdsat/watch1.mli b/src/cdsat/watch1.mli index e377f613..aa26a63f 100644 --- a/src/cdsat/watch1.mli +++ b/src/cdsat/watch1.mli @@ -2,17 +2,19 @@ type t -val dummy : t -val make : TVar.t list -> t +val make : TVar.store -> TVar.t list -> t -val make_a : TVar.t array -> t +val make_a : TVar.store -> TVar.t array -> t (** owns the array *) +val alive : t -> bool +val kill : t -> unit + val iter : t -> TVar.t Iter.t (** current watch(es) *) -val init : TVar.store -> t -> TVar.t -> on_all_set:(unit -> unit) -> unit -(** [init tstore w t ~on_all_set] initializes [w] (the watchlist) for +val init : t -> TVar.t -> on_all_set:(unit -> unit) -> unit +(** [init w t ~on_all_set] initializes [w] (the watchlist) for var [t], by finding an unassigned TVar.t in the watchlist and registering [t] to it. If all vars are set, then it watches the one with the highest level @@ -20,13 +22,8 @@ val init : TVar.store -> t -> TVar.t -> on_all_set:(unit -> unit) -> unit *) val update : - TVar.store -> - t -> - TVar.t -> - watch:TVar.t -> - on_all_set:(unit -> unit) -> - Watch_res.t -(** [update tstore w t ~watch ~on_all_set] updates [w] after [watch] + t -> TVar.t -> watch:TVar.t -> on_all_set:(unit -> unit) -> Watch_res.t +(** [update w t ~watch ~on_all_set] updates [w] after [watch] has been assigned. It looks for another TVar.t in [w] for [t] to watch. If all vars are set, then it calls [on_all_set ()] *) diff --git a/src/cdsat/watch2.ml b/src/cdsat/watch2.ml index 31e44514..b25765c1 100644 --- a/src/cdsat/watch2.ml +++ b/src/cdsat/watch2.ml @@ -1,24 +1,26 @@ open Watch_utils_ -type t = TVar.t array +type t = { vst: TVar.store; arr: TVar.t array; mutable alive: bool } let dummy = [||] -let make = Array.of_list -let[@inline] make_a a : t = a +let make vst l : t = { alive = true; vst; arr = Array.of_list l } +let[@inline] make_a vst arr : t = { vst; arr; alive = true } +let[@inline] alive self = self.alive +let[@inline] kill self = self.alive <- false -let[@inline] iter w k = - if Array.length w > 0 then ( - k w.(0); - if Array.length w > 1 then k w.(1) +let[@inline] iter (self : t) k = + if Array.length self.arr > 0 then ( + k self.arr.(0); + if Array.length self.arr > 1 then k self.arr.(1) ) -let init tst w t ~on_unit ~on_all_set : unit = - let i0, all_set0 = find_watch_ tst w 0 0 in - Util.swap_array w i0 0; - let i1, all_set1 = find_watch_ tst w 1 0 in - Util.swap_array w i1 1; - TVar.add_watcher tst w.(0) ~watcher:t; - TVar.add_watcher tst w.(1) ~watcher:t; +let init (self : t) t ~on_unit ~on_all_set : unit = + let i0, all_set0 = find_watch_ self.vst self.arr 0 0 in + Util.swap_array self.arr i0 0; + let i1, all_set1 = find_watch_ self.vst self.arr 1 0 in + Util.swap_array self.arr i1 1; + TVar.add_watcher self.vst self.arr.(0) ~watcher:t; + TVar.add_watcher self.vst self.arr.(1) ~watcher:t; assert ( if all_set0 then all_set1 @@ -27,30 +29,33 @@ let init tst w t ~on_unit ~on_all_set : unit = if all_set0 then on_all_set () else if all_set1 then ( - assert (not (TVar.has_value tst w.(0))); - on_unit w.(0) + assert (not (TVar.has_value self.vst self.arr.(0))); + on_unit self.arr.(0) ); () -let update tst w t ~watch ~on_unit ~on_all_set : Watch_res.t = - (* find another watch. If none is present, keep the - current ones and call [on_unit] or [on_all_set]. *) - if w.(0) == watch then - (* ensure that if there is only one watch, it's the first *) - Util.swap_array w 0 1 - else - assert (w.(1) == watch); - let i, all_set1 = find_watch_ tst w 1 0 in - if all_set1 then ( - if TVar.has_value tst w.(0) then - on_all_set () +let update (self : t) t ~watch ~on_unit ~on_all_set : Watch_res.t = + if self.alive then ( + (* find another watch. If none is present, keep the + current ones and call [on_unit] or [on_all_set]. *) + if self.arr.(0) == watch then + (* ensure that if there is only one watch, it's the first *) + Util.swap_array self.arr 0 1 else - on_unit w.(0); - Watch_res.Watch_keep (* just keep current watch *) - ) else ( - (* use [i] as the second watch *) - assert (i > 1); - Util.swap_array w i 1; - TVar.add_watcher tst w.(1) ~watcher:t; + assert (self.arr.(1) == watch); + let i, all_set1 = find_watch_ self.vst self.arr 1 0 in + if all_set1 then ( + if TVar.has_value self.vst self.arr.(0) then + on_all_set () + else + on_unit self.arr.(0); + Watch_res.Watch_keep (* just keep current watch *) + ) else ( + (* use [i] as the second watch *) + assert (i > 1); + Util.swap_array self.arr i 1; + TVar.add_watcher self.vst self.arr.(1) ~watcher:t; + Watch_res.Watch_remove + ) + ) else Watch_res.Watch_remove - ) diff --git a/src/cdsat/watch2.mli b/src/cdsat/watch2.mli index df21a89e..3400ab5f 100644 --- a/src/cdsat/watch2.mli +++ b/src/cdsat/watch2.mli @@ -2,23 +2,23 @@ type t -val dummy : t -val make : TVar.t list -> t +val make : TVar.store -> TVar.t list -> t -val make_a : TVar.t array -> t +val make_a : TVar.store -> TVar.t array -> t (** owns the array *) val iter : t -> TVar.t Iter.t (** current watch(es) *) +val kill : t -> unit +(** Disable the watch. It will be removed lazily. *) + +val alive : t -> bool +(** Is the watch alive? *) + val init : - TVar.store -> - t -> - TVar.t -> - on_unit:(TVar.t -> unit) -> - on_all_set:(unit -> unit) -> - unit -(** [init tstore w t ~on_all_set] initializes [w] (the watchlist) for + t -> TVar.t -> on_unit:(TVar.t -> unit) -> on_all_set:(unit -> unit) -> unit +(** [init w t ~on_all_set] initializes [w] (the watchlist) for var [t], by finding an unassigned var in the watchlist and registering [t] to it. If exactly one TVar.t [u] is not set, then it calls [on_unit u]. @@ -27,14 +27,13 @@ val init : *) val update : - TVar.store -> t -> TVar.t -> watch:TVar.t -> on_unit:(TVar.t -> unit) -> on_all_set:(unit -> unit) -> Watch_res.t -(** [update tstore w t ~watch ~on_all_set] updates [w] after [watch] +(** [update w t ~watch ~on_all_set] updates [w] after [watch] has been assigned. It looks for another var in [w] for [t] to watch. If exactly one var [u] is not set, then it calls [on_unit u]. If all vars are set, then it calls [on_all_set ()] diff --git a/src/main/main.ml b/src/main/main.ml index 61e2b927..aaab5779 100644 --- a/src/main/main.ml +++ b/src/main/main.ml @@ -191,7 +191,12 @@ let main_smt ~config () : _ result = let vst = TVar.Store.create tst in let arg = (module struct - let or_l = Sidekick_base.Form.or_l + module Core = Sidekick_base.Form + module Bool = Sidekick_base.Form + + module UF = struct + let is_unin_function = Sidekick_base.Uconst.is_uconst + end end : Solver.ARG) in Solver.create tst vst ~arg ~proof_tracer:(tracer :> Proof.Tracer.t) ()