wip(cdsat): start implementing propagation

This commit is contained in:
Simon Cruanes 2022-11-04 22:04:42 -04:00
parent 6f1abedb44
commit 6374fd7d5f
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
11 changed files with 204 additions and 25 deletions

6
src/cdsat/conflict.ml Normal file
View file

@ -0,0 +1,6 @@
type t = { vst: TVar.store; lit: TLit.t; propagate_reason: Reason.t }
let pp out (self : t) =
Fmt.fprintf out "(@[conflict %a@])" (TLit.pp self.vst) self.lit
let make vst ~lit ~propagate_reason () : t = { vst; lit; propagate_reason }

7
src/cdsat/conflict.mli Normal file
View file

@ -0,0 +1,7 @@
(** Conflict discovered during search *)
type t = { vst: TVar.store; lit: TLit.t; propagate_reason: Reason.t }
include Sidekick_sigs.PRINT with type t := t
val make : TVar.store -> lit:TLit.t -> propagate_reason:Reason.t -> unit -> t

View file

@ -1,6 +1,7 @@
open Sidekick_core open Sidekick_core
module Proof = Sidekick_proof module Proof = Sidekick_proof
module Check_res = Sidekick_abstract_solver.Check_res module Check_res = Sidekick_abstract_solver.Check_res
module Unknown = Sidekick_abstract_solver.Unknown
(** Argument to the solver *) (** Argument to the solver *)
module type ARG = sig module type ARG = sig
@ -8,6 +9,16 @@ module type ARG = sig
(** build a disjunction *) (** build a disjunction *)
end end
(* TODO: embed a simplifier, and have simplify hooks in plugins.
Then use the simplifier on any asserted term *)
type pending_assignment = {
var: TVar.t;
value: Value.t;
level: int;
reason: Reason.t;
}
type t = { type t = {
tst: Term.store; tst: Term.store;
vst: TVar.store; vst: TVar.store;
@ -16,12 +27,26 @@ type t = {
trail: Trail.t; trail: Trail.t;
plugins: plugin Vec.t; plugins: plugin Vec.t;
term_to_var: Term_to_var.t; term_to_var: Term_to_var.t;
pending_assignments: pending_assignment Vec.t;
mutable last_res: Check_res.t option; mutable last_res: Check_res.t option;
proof_tracer: Proof.Tracer.t; proof_tracer: Proof.Tracer.t;
n_conflicts: int Stat.counter;
n_propagations: int Stat.counter;
n_restarts: int Stat.counter;
} }
and plugin_action = t and plugin_action = t
(* FIXME:
- add [on_add_var: TVar.t -> unit]
and [on_remove_var: TVar.t -> unit].
these are called when a variable becomes relevant/is removed or GC'd
(in particular: setup watches + var constraints on add,
kill watches and remove constraints on remove)
- add [gc_mark : TVar.t -> recurse:(TVar.t -> unit) -> unit]
to mark sub-variables during GC mark phase.
*)
and plugin = and plugin =
| P : { | P : {
st: 'st; st: 'st;
@ -42,13 +67,17 @@ let create ?(stats = Stat.create ()) ~arg tst vst ~proof_tracer () : t =
stats; stats;
trail = Trail.create (); trail = Trail.create ();
plugins = Vec.create (); plugins = Vec.create ();
pending_assignments = Vec.create ();
term_to_var = Term_to_var.create vst; term_to_var = Term_to_var.create vst;
last_res = None; last_res = None;
proof_tracer; proof_tracer;
n_restarts = Stat.mk_int stats "cdsat.restarts";
n_conflicts = Stat.mk_int stats "cdsat.conflicts";
n_propagations = Stat.mk_int stats "cdsat.propagations";
} }
let[@inline] trail self = self.trail let[@inline] trail self = self.trail
let[@inline] iter_plugins self f = Vec.iter ~f self.plugins let[@inline] iter_plugins self ~f = Vec.iter ~f self.plugins
let[@inline] tst self = self.tst let[@inline] tst self = self.tst
let[@inline] vst self = self.vst let[@inline] vst self = self.vst
let[@inline] last_res self = self.last_res let[@inline] last_res self = self.last_res
@ -79,10 +108,30 @@ let push_level (self : t) : unit =
() ()
let pop_levels (self : t) n : unit = let pop_levels (self : t) n : unit =
let {
tst = _;
vst = _;
arg = _;
stats = _;
trail;
plugins;
term_to_var = _;
pending_assignments;
last_res = _;
proof_tracer = _;
n_propagations = _;
n_conflicts = _;
n_restarts = _;
} =
self
in
Log.debugf 50 (fun k -> k "(@[cdsat.core.pop-levels %d@])" n); Log.debugf 50 (fun k -> k "(@[cdsat.core.pop-levels %d@])" n);
if n > 0 then self.last_res <- None; if n > 0 then (
Trail.pop_levels self.trail n ~f:(fun v -> TVar.unassign self.vst v); self.last_res <- None;
Vec.iter self.plugins ~f:(fun (P p) -> p.pop_levels p.st n); Vec.clear pending_assignments
);
Trail.pop_levels trail n ~f:(fun v -> TVar.unassign self.vst v);
Vec.iter plugins ~f:(fun (P p) -> p.pop_levels p.st n);
() ()
(* term to var *) (* term to var *)
@ -105,28 +154,123 @@ let add_plugin self (pb : Plugin.builder) : unit =
let add_ty (_self : t) ~ty:_ : unit = () let add_ty (_self : t) ~ty:_ : unit = ()
(* Assign [v <- value] for [reason] at [level].
This assignment is delayed. *)
let assign (self : t) (v : TVar.t) ~(value : Value.t) ~level:v_level ~reason : let assign (self : t) (v : TVar.t) ~(value : Value.t) ~level:v_level ~reason :
unit = unit =
Log.debugf 50 (fun k -> Log.debugf 50 (fun k ->
k "(@[cdsat.core.assign@ `%a`@ @[<- %a@]@ :reason %a@])" k "(@[cdsat.core.assign@ `%a`@ @[<- %a@]@ :reason %a@])"
(TVar.pp self.vst) v Value.pp value Reason.pp reason); (TVar.pp self.vst) v Value.pp value Reason.pp reason);
self.last_res <- None; self.last_res <- None;
Vec.push self.pending_assignments { var = v; value; level = v_level; reason }
exception E_conflict of Conflict.t
let raise_conflict (c : Conflict.t) : 'a = raise (E_conflict c)
(* add pending assignments to the trail. This might trigger a conflict
in case an assignment contradicts an already existing assignment. *)
let perform_pending_assignments (self : t) : unit =
while not (Vec.is_empty self.pending_assignments) do
let { var = v; level = v_level; value; reason } =
Vec.pop_exn self.pending_assignments
in
match TVar.value self.vst v with match TVar.value self.vst v with
| None -> | None ->
TVar.assign self.vst v ~value ~level:v_level ~reason; TVar.assign self.vst v ~value ~level:v_level ~reason;
Trail.push_assignment self.trail v Trail.push_assignment self.trail v
| Some value' when Value.equal value value' -> () (* idempotent *) | Some value' when Value.equal value value' -> () (* idempotent *)
| Some value' -> | Some _value' ->
(* TODO: conflict *) (* conflict should only occur on booleans since they're the only
Log.debugf 0 (fun k -> k "TODO: conflict (incompatible values)"); propagation-able variables *)
assert (Term.is_a_bool (TVar.term self.vst v));
Log.debugf 0 (fun k ->
k "TODO: conflict (incompatible values for %a)" (TVar.pp self.vst) v);
raise_conflict
@@ Conflict.make self.vst ~lit:(TLit.make true v) ~propagate_reason:reason
() ()
done
let propagate (self : t) : Conflict.t option =
let@ () = Profile.with_ "cdsat.propagate" in
try
let continue = ref true in
while !continue do
perform_pending_assignments self;
while Trail.head self.trail < Trail.size self.trail do
let var = Trail.get self.trail (Trail.head self.trail) in
(* TODO: call plugins *)
Log.debugf 0 (fun k -> k "TODO: propagate %a" (TVar.pp self.vst) var);
let value =
match TVar.value self.vst var with
| Some v -> v
| None -> assert false
in
iter_plugins self ~f:(fun (P p) -> p.propagate p.st self var value);
(* move to next var *)
Trail.set_head self.trail (Trail.head self.trail + 1)
done;
(* did we reach fixpoint? *)
if Vec.is_empty self.pending_assignments then continue := false
done;
None
with E_conflict c -> Some c
let solve ~on_exit ~on_progress ~should_stop ~assumptions (self : t) : let solve ~on_exit ~on_progress ~should_stop ~assumptions (self : t) :
Check_res.t = Check_res.t =
let@ () = Profile.with_ "cdsat.solve" in
self.last_res <- None; self.last_res <- None;
(* TODO: outer loop (propagate; decide)* *)
(* TODO: propagation loop, involving plugins *) (* FIXME: handle assumptions.
- do assumptions first when deciding (forced decisions)
- in case of conflict below assumptions len, special conflict analysis to
compute unsat core
*)
(* control if loop stops *)
let continue = ref true in
let n_conflicts = ref 0 in
let res = ref (Check_res.Unknown Unknown.U_incomplete) in
(* main loop *)
while !continue do
if !n_conflicts mod 64 = 0 then on_progress ();
(* propagate *)
(match propagate self with
| Some c ->
Log.debugf 1 (fun k ->
k "(@[cdsat.propagate.found-conflict@ %a@])" Conflict.pp c);
incr n_conflicts;
Stat.incr self.n_conflicts;
(* TODO: handle conflict, learn a clause or declare unsat *)
(* TODO: see if we want to restart *)
assert false assert false
| None ->
Log.debugf 0 (fun k -> k "TODO: decide");
(* TODO: decide *)
());
(* regularly check if it's time to stop *)
if !n_conflicts mod 64 = 0 then
if should_stop !n_conflicts then (
Log.debugf 1 (fun k -> k "(@[cdsat.stop@ :caused-by-callback@])");
res := Check_res.Unknown Unknown.U_asked_to_stop;
continue := false
)
done;
(* cleanup and exit *)
List.iter (fun f -> f ()) on_exit;
!res
(* plugin actions *) (* plugin actions *)

View file

@ -55,7 +55,7 @@ val tst : t -> Term.store
val vst : t -> TVar.store val vst : t -> TVar.store
val trail : t -> Trail.t val trail : t -> Trail.t
val add_plugin : t -> Plugin.builder -> unit val add_plugin : t -> Plugin.builder -> unit
val iter_plugins : t -> Plugin.t Iter.t val iter_plugins : t -> f:(Plugin.t -> unit) -> unit
val last_res : t -> Check_res.t option val last_res : t -> Check_res.t option
(** Last result. Reset on backtracking/assertion *) (** Last result. Reset on backtracking/assertion *)

View file

@ -31,6 +31,8 @@ let decide (self : t) (v : TVar.t) : Value.t option =
let propagate (self : t) (act : Core.Plugin_action.t) (v : TVar.t) let propagate (self : t) (act : Core.Plugin_action.t) (v : TVar.t)
(value : Value.t) : unit = (value : Value.t) : unit =
Log.debugf 0 (fun k ->
k "(@[bool-plugin.propagate %a@])" (TVar.pp self.vst) v);
() ()
(* TODO: BCP *) (* TODO: BCP *)

View file

@ -40,7 +40,7 @@ let propagate (self : t) _act (v : TVar.t) (value : Value.t) =
| Unin_const _ -> () | Unin_const _ -> ()
| Unin_fun { f = _; args } -> | Unin_fun { f = _; args } ->
(* TODO: update congruence table *) (* TODO: update congruence table *)
Log.debugf 1 (fun k -> k "FIXME: congruence rule"); Log.debugf 0 (fun k -> k "FIXME: congruence rule");
() ()
| _ -> () | _ -> ()

View file

@ -33,7 +33,7 @@ let create ?(stats = Stat.create ()) ~(arg : (module ARG)) tst vst ~proof_tracer
let[@inline] core self = self.core let[@inline] core self = self.core
let add_plugin self p = Core.add_plugin self.core p let add_plugin self p = Core.add_plugin self.core p
let[@inline] iter_plugins self f = Core.iter_plugins self.core f let[@inline] iter_plugins self f = Core.iter_plugins self.core ~f
let[@inline] tst self = self.tst let[@inline] tst self = self.tst
let[@inline] vst self = self.vst let[@inline] vst self = self.vst

View file

@ -1,15 +1,21 @@
module VVec = TVar.Vec_of module VVec = TVar.Vec_of
type t = { vars: VVec.t; levels: Veci.t } type t = { vars: VVec.t; levels: Veci.t; mutable head: int }
let create () : t = { vars = VVec.create (); levels = Veci.create () } let create () : t = { vars = VVec.create (); levels = Veci.create (); head = 0 }
let[@inline] push_assignment (self : t) (v : TVar.t) : unit = let[@inline] push_assignment (self : t) (v : TVar.t) : unit =
VVec.push self.vars v VVec.push self.vars v
let[@inline] var_at (self : t) (i : int) : TVar.t = VVec.get self.vars i let[@inline] get (self : t) (i : int) : TVar.t = VVec.get self.vars i
let[@inline] size (self : t) : int = VVec.size self.vars
let[@inline] n_levels self = Veci.size self.levels let[@inline] n_levels self = Veci.size self.levels
let push_level (self : t) : unit = Veci.push self.levels (VVec.size self.vars)
let[@inline] push_level (self : t) : unit =
Veci.push self.levels (VVec.size self.vars)
let[@inline] head self = self.head
let[@inline] set_head self n = self.head <- n
let pop_levels (self : t) (n : int) ~f : unit = let pop_levels (self : t) (n : int) ~f : unit =
if n <= n_levels self then ( if n <= n_levels self then (
@ -18,5 +24,7 @@ let pop_levels (self : t) (n : int) ~f : unit =
while VVec.size self.vars > idx do while VVec.size self.vars > idx do
let elt = VVec.pop self.vars in let elt = VVec.pop self.vars in
f elt f elt
done done;
(* also reset head *)
if self.head > idx then self.head <- idx
) )

View file

@ -3,7 +3,10 @@
type t type t
val create : unit -> t val create : unit -> t
val var_at : t -> int -> TVar.t val get : t -> int -> TVar.t
val size : t -> int
val push_assignment : t -> TVar.t -> unit val push_assignment : t -> TVar.t -> unit
val head : t -> int
val set_head : t -> int -> unit
include Sidekick_sigs.BACKTRACKABLE0_CB with type t := t and type elt := TVar.t include Sidekick_sigs.BACKTRACKABLE0_CB with type t := t and type elt := TVar.t

View file

@ -126,6 +126,8 @@ let is_bool t =
| E_const { c_view = C_bool; _ } -> true | E_const { c_view = C_bool; _ } -> true
| _ -> false | _ -> false
let[@inline] is_a_bool t = is_bool (ty t)
let is_true t = let is_true t =
match view t with match view t with
| E_const { c_view = C_true; _ } -> true | E_const { c_view = C_true; _ } -> true

View file

@ -24,7 +24,14 @@ val ite : store -> t -> t -> t -> t
(** [ite a b c] is [if a then b else c] *) (** [ite a b c] is [if a then b else c] *)
val is_eq : t -> bool val is_eq : t -> bool
(** [is_eq t] is true if [t] is the [=] constant *)
val is_bool : t -> bool val is_bool : t -> bool
(** [is_bool t] is true if [t] is the type bool itself *)
val is_a_bool : t -> bool
(** [is_a_bool t] is true if [t] has type [bool] *)
val is_true : t -> bool val is_true : t -> bool
val is_false : t -> bool val is_false : t -> bool