From 6374fd7d5f71c2e59285822a4af70f67ba71d051 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Fri, 4 Nov 2022 22:04:42 -0400 Subject: [PATCH] wip(cdsat): start implementing propagation --- src/cdsat/conflict.ml | 6 + src/cdsat/conflict.mli | 7 ++ src/cdsat/core.ml | 176 +++++++++++++++++++++++++++--- src/cdsat/core.mli | 2 +- src/cdsat/plugin_bool.ml | 2 + src/cdsat/plugin_uninterpreted.ml | 2 +- src/cdsat/solver.ml | 2 +- src/cdsat/trail.ml | 18 ++- src/cdsat/trail.mli | 5 +- src/core-logic/t_builtins.ml | 2 + src/core-logic/t_builtins.mli | 7 ++ 11 files changed, 204 insertions(+), 25 deletions(-) create mode 100644 src/cdsat/conflict.ml create mode 100644 src/cdsat/conflict.mli diff --git a/src/cdsat/conflict.ml b/src/cdsat/conflict.ml new file mode 100644 index 00000000..0f84d191 --- /dev/null +++ b/src/cdsat/conflict.ml @@ -0,0 +1,6 @@ +type t = { vst: TVar.store; lit: TLit.t; propagate_reason: Reason.t } + +let pp out (self : t) = + Fmt.fprintf out "(@[conflict %a@])" (TLit.pp self.vst) self.lit + +let make vst ~lit ~propagate_reason () : t = { vst; lit; propagate_reason } diff --git a/src/cdsat/conflict.mli b/src/cdsat/conflict.mli new file mode 100644 index 00000000..3b1158c2 --- /dev/null +++ b/src/cdsat/conflict.mli @@ -0,0 +1,7 @@ +(** Conflict discovered during search *) + +type t = { vst: TVar.store; lit: TLit.t; propagate_reason: Reason.t } + +include Sidekick_sigs.PRINT with type t := t + +val make : TVar.store -> lit:TLit.t -> propagate_reason:Reason.t -> unit -> t diff --git a/src/cdsat/core.ml b/src/cdsat/core.ml index 83ad0da8..e27c7e15 100644 --- a/src/cdsat/core.ml +++ b/src/cdsat/core.ml @@ -1,6 +1,7 @@ open Sidekick_core module Proof = Sidekick_proof module Check_res = Sidekick_abstract_solver.Check_res +module Unknown = Sidekick_abstract_solver.Unknown (** Argument to the solver *) module type ARG = sig @@ -8,6 +9,16 @@ module type ARG = sig (** build a disjunction *) end +(* TODO: embed a simplifier, and have simplify hooks in plugins. + Then use the simplifier on any asserted term *) + +type pending_assignment = { + var: TVar.t; + value: Value.t; + level: int; + reason: Reason.t; +} + type t = { tst: Term.store; vst: TVar.store; @@ -16,12 +27,26 @@ type t = { trail: Trail.t; plugins: plugin Vec.t; term_to_var: Term_to_var.t; + pending_assignments: pending_assignment Vec.t; mutable last_res: Check_res.t option; proof_tracer: Proof.Tracer.t; + n_conflicts: int Stat.counter; + n_propagations: int Stat.counter; + n_restarts: int Stat.counter; } and plugin_action = t +(* FIXME: + - add [on_add_var: TVar.t -> unit] + and [on_remove_var: TVar.t -> unit]. + these are called when a variable becomes relevant/is removed or GC'd + (in particular: setup watches + var constraints on add, + kill watches and remove constraints on remove) + + - add [gc_mark : TVar.t -> recurse:(TVar.t -> unit) -> unit] + to mark sub-variables during GC mark phase. +*) and plugin = | P : { st: 'st; @@ -42,13 +67,17 @@ let create ?(stats = Stat.create ()) ~arg tst vst ~proof_tracer () : t = stats; trail = Trail.create (); plugins = Vec.create (); + pending_assignments = Vec.create (); term_to_var = Term_to_var.create vst; last_res = None; proof_tracer; + n_restarts = Stat.mk_int stats "cdsat.restarts"; + n_conflicts = Stat.mk_int stats "cdsat.conflicts"; + n_propagations = Stat.mk_int stats "cdsat.propagations"; } let[@inline] trail self = self.trail -let[@inline] iter_plugins self f = Vec.iter ~f self.plugins +let[@inline] iter_plugins self ~f = Vec.iter ~f self.plugins let[@inline] tst self = self.tst let[@inline] vst self = self.vst let[@inline] last_res self = self.last_res @@ -79,10 +108,30 @@ let push_level (self : t) : unit = () let pop_levels (self : t) n : unit = + let { + tst = _; + vst = _; + arg = _; + stats = _; + trail; + plugins; + term_to_var = _; + pending_assignments; + last_res = _; + proof_tracer = _; + n_propagations = _; + n_conflicts = _; + n_restarts = _; + } = + self + in Log.debugf 50 (fun k -> k "(@[cdsat.core.pop-levels %d@])" n); - if n > 0 then self.last_res <- None; - Trail.pop_levels self.trail n ~f:(fun v -> TVar.unassign self.vst v); - Vec.iter self.plugins ~f:(fun (P p) -> p.pop_levels p.st n); + if n > 0 then ( + self.last_res <- None; + Vec.clear pending_assignments + ); + Trail.pop_levels trail n ~f:(fun v -> TVar.unassign self.vst v); + Vec.iter plugins ~f:(fun (P p) -> p.pop_levels p.st n); () (* term to var *) @@ -105,28 +154,123 @@ let add_plugin self (pb : Plugin.builder) : unit = let add_ty (_self : t) ~ty:_ : unit = () +(* Assign [v <- value] for [reason] at [level]. + This assignment is delayed. *) let assign (self : t) (v : TVar.t) ~(value : Value.t) ~level:v_level ~reason : unit = Log.debugf 50 (fun k -> k "(@[cdsat.core.assign@ `%a`@ @[<- %a@]@ :reason %a@])" (TVar.pp self.vst) v Value.pp value Reason.pp reason); self.last_res <- None; - match TVar.value self.vst v with - | None -> - TVar.assign self.vst v ~value ~level:v_level ~reason; - Trail.push_assignment self.trail v - | Some value' when Value.equal value value' -> () (* idempotent *) - | Some value' -> - (* TODO: conflict *) - Log.debugf 0 (fun k -> k "TODO: conflict (incompatible values)"); - () + Vec.push self.pending_assignments { var = v; value; level = v_level; reason } + +exception E_conflict of Conflict.t + +let raise_conflict (c : Conflict.t) : 'a = raise (E_conflict c) + +(* add pending assignments to the trail. This might trigger a conflict + in case an assignment contradicts an already existing assignment. *) +let perform_pending_assignments (self : t) : unit = + while not (Vec.is_empty self.pending_assignments) do + let { var = v; level = v_level; value; reason } = + Vec.pop_exn self.pending_assignments + in + match TVar.value self.vst v with + | None -> + TVar.assign self.vst v ~value ~level:v_level ~reason; + Trail.push_assignment self.trail v + | Some value' when Value.equal value value' -> () (* idempotent *) + | Some _value' -> + (* conflict should only occur on booleans since they're the only + propagation-able variables *) + assert (Term.is_a_bool (TVar.term self.vst v)); + Log.debugf 0 (fun k -> + k "TODO: conflict (incompatible values for %a)" (TVar.pp self.vst) v); + raise_conflict + @@ Conflict.make self.vst ~lit:(TLit.make true v) ~propagate_reason:reason + () + done + +let propagate (self : t) : Conflict.t option = + let@ () = Profile.with_ "cdsat.propagate" in + try + let continue = ref true in + while !continue do + perform_pending_assignments self; + + while Trail.head self.trail < Trail.size self.trail do + let var = Trail.get self.trail (Trail.head self.trail) in + + (* TODO: call plugins *) + Log.debugf 0 (fun k -> k "TODO: propagate %a" (TVar.pp self.vst) var); + + let value = + match TVar.value self.vst var with + | Some v -> v + | None -> assert false + in + + iter_plugins self ~f:(fun (P p) -> p.propagate p.st self var value); + + (* move to next var *) + Trail.set_head self.trail (Trail.head self.trail + 1) + done; + + (* did we reach fixpoint? *) + if Vec.is_empty self.pending_assignments then continue := false + done; + None + with E_conflict c -> Some c let solve ~on_exit ~on_progress ~should_stop ~assumptions (self : t) : Check_res.t = + let@ () = Profile.with_ "cdsat.solve" in self.last_res <- None; - (* TODO: outer loop (propagate; decide)* *) - (* TODO: propagation loop, involving plugins *) - assert false + + (* FIXME: handle assumptions. + - do assumptions first when deciding (forced decisions) + - in case of conflict below assumptions len, special conflict analysis to + compute unsat core + *) + + (* control if loop stops *) + let continue = ref true in + let n_conflicts = ref 0 in + let res = ref (Check_res.Unknown Unknown.U_incomplete) in + + (* main loop *) + while !continue do + if !n_conflicts mod 64 = 0 then on_progress (); + + (* propagate *) + (match propagate self with + | Some c -> + Log.debugf 1 (fun k -> + k "(@[cdsat.propagate.found-conflict@ %a@])" Conflict.pp c); + incr n_conflicts; + Stat.incr self.n_conflicts; + + (* TODO: handle conflict, learn a clause or declare unsat *) + (* TODO: see if we want to restart *) + assert false + | None -> + Log.debugf 0 (fun k -> k "TODO: decide"); + (* TODO: decide *) + ()); + + (* regularly check if it's time to stop *) + if !n_conflicts mod 64 = 0 then + if should_stop !n_conflicts then ( + Log.debugf 1 (fun k -> k "(@[cdsat.stop@ :caused-by-callback@])"); + + res := Check_res.Unknown Unknown.U_asked_to_stop; + continue := false + ) + done; + + (* cleanup and exit *) + List.iter (fun f -> f ()) on_exit; + !res (* plugin actions *) diff --git a/src/cdsat/core.mli b/src/cdsat/core.mli index 3ce21cae..7f9caae8 100644 --- a/src/cdsat/core.mli +++ b/src/cdsat/core.mli @@ -55,7 +55,7 @@ val tst : t -> Term.store val vst : t -> TVar.store val trail : t -> Trail.t val add_plugin : t -> Plugin.builder -> unit -val iter_plugins : t -> Plugin.t Iter.t +val iter_plugins : t -> f:(Plugin.t -> unit) -> unit val last_res : t -> Check_res.t option (** Last result. Reset on backtracking/assertion *) diff --git a/src/cdsat/plugin_bool.ml b/src/cdsat/plugin_bool.ml index d80a8975..2e735040 100644 --- a/src/cdsat/plugin_bool.ml +++ b/src/cdsat/plugin_bool.ml @@ -31,6 +31,8 @@ let decide (self : t) (v : TVar.t) : Value.t option = let propagate (self : t) (act : Core.Plugin_action.t) (v : TVar.t) (value : Value.t) : unit = + Log.debugf 0 (fun k -> + k "(@[bool-plugin.propagate %a@])" (TVar.pp self.vst) v); () (* TODO: BCP *) diff --git a/src/cdsat/plugin_uninterpreted.ml b/src/cdsat/plugin_uninterpreted.ml index f2ebf660..759448c1 100644 --- a/src/cdsat/plugin_uninterpreted.ml +++ b/src/cdsat/plugin_uninterpreted.ml @@ -40,7 +40,7 @@ let propagate (self : t) _act (v : TVar.t) (value : Value.t) = | Unin_const _ -> () | Unin_fun { f = _; args } -> (* TODO: update congruence table *) - Log.debugf 1 (fun k -> k "FIXME: congruence rule"); + Log.debugf 0 (fun k -> k "FIXME: congruence rule"); () | _ -> () diff --git a/src/cdsat/solver.ml b/src/cdsat/solver.ml index 554db36a..9560cccc 100644 --- a/src/cdsat/solver.ml +++ b/src/cdsat/solver.ml @@ -33,7 +33,7 @@ let create ?(stats = Stat.create ()) ~(arg : (module ARG)) tst vst ~proof_tracer let[@inline] core self = self.core let add_plugin self p = Core.add_plugin self.core p -let[@inline] iter_plugins self f = Core.iter_plugins self.core f +let[@inline] iter_plugins self f = Core.iter_plugins self.core ~f let[@inline] tst self = self.tst let[@inline] vst self = self.vst diff --git a/src/cdsat/trail.ml b/src/cdsat/trail.ml index 008812e8..dbcf0750 100644 --- a/src/cdsat/trail.ml +++ b/src/cdsat/trail.ml @@ -1,15 +1,21 @@ module VVec = TVar.Vec_of -type t = { vars: VVec.t; levels: Veci.t } +type t = { vars: VVec.t; levels: Veci.t; mutable head: int } -let create () : t = { vars = VVec.create (); levels = Veci.create () } +let create () : t = { vars = VVec.create (); levels = Veci.create (); head = 0 } let[@inline] push_assignment (self : t) (v : TVar.t) : unit = VVec.push self.vars v -let[@inline] var_at (self : t) (i : int) : TVar.t = VVec.get self.vars i +let[@inline] get (self : t) (i : int) : TVar.t = VVec.get self.vars i +let[@inline] size (self : t) : int = VVec.size self.vars let[@inline] n_levels self = Veci.size self.levels -let push_level (self : t) : unit = Veci.push self.levels (VVec.size self.vars) + +let[@inline] push_level (self : t) : unit = + Veci.push self.levels (VVec.size self.vars) + +let[@inline] head self = self.head +let[@inline] set_head self n = self.head <- n let pop_levels (self : t) (n : int) ~f : unit = if n <= n_levels self then ( @@ -18,5 +24,7 @@ let pop_levels (self : t) (n : int) ~f : unit = while VVec.size self.vars > idx do let elt = VVec.pop self.vars in f elt - done + done; + (* also reset head *) + if self.head > idx then self.head <- idx ) diff --git a/src/cdsat/trail.mli b/src/cdsat/trail.mli index 00b8b5a1..877c3663 100644 --- a/src/cdsat/trail.mli +++ b/src/cdsat/trail.mli @@ -3,7 +3,10 @@ type t val create : unit -> t -val var_at : t -> int -> TVar.t +val get : t -> int -> TVar.t +val size : t -> int val push_assignment : t -> TVar.t -> unit +val head : t -> int +val set_head : t -> int -> unit include Sidekick_sigs.BACKTRACKABLE0_CB with type t := t and type elt := TVar.t diff --git a/src/core-logic/t_builtins.ml b/src/core-logic/t_builtins.ml index 019aa762..45e65451 100644 --- a/src/core-logic/t_builtins.ml +++ b/src/core-logic/t_builtins.ml @@ -126,6 +126,8 @@ let is_bool t = | E_const { c_view = C_bool; _ } -> true | _ -> false +let[@inline] is_a_bool t = is_bool (ty t) + let is_true t = match view t with | E_const { c_view = C_true; _ } -> true diff --git a/src/core-logic/t_builtins.mli b/src/core-logic/t_builtins.mli index 01648e9e..4d865507 100644 --- a/src/core-logic/t_builtins.mli +++ b/src/core-logic/t_builtins.mli @@ -24,7 +24,14 @@ val ite : store -> t -> t -> t -> t (** [ite a b c] is [if a then b else c] *) val is_eq : t -> bool +(** [is_eq t] is true if [t] is the [=] constant *) + val is_bool : t -> bool +(** [is_bool t] is true if [t] is the type bool itself *) + +val is_a_bool : t -> bool +(** [is_a_bool t] is true if [t] has type [bool] *) + val is_true : t -> bool val is_false : t -> bool