mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-05 19:00:33 -05:00
wip(cdsat): start implementing propagation
This commit is contained in:
parent
6f1abedb44
commit
6374fd7d5f
11 changed files with 204 additions and 25 deletions
6
src/cdsat/conflict.ml
Normal file
6
src/cdsat/conflict.ml
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
type t = { vst: TVar.store; lit: TLit.t; propagate_reason: Reason.t }
|
||||
|
||||
let pp out (self : t) =
|
||||
Fmt.fprintf out "(@[conflict %a@])" (TLit.pp self.vst) self.lit
|
||||
|
||||
let make vst ~lit ~propagate_reason () : t = { vst; lit; propagate_reason }
|
||||
7
src/cdsat/conflict.mli
Normal file
7
src/cdsat/conflict.mli
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
(** Conflict discovered during search *)
|
||||
|
||||
type t = { vst: TVar.store; lit: TLit.t; propagate_reason: Reason.t }
|
||||
|
||||
include Sidekick_sigs.PRINT with type t := t
|
||||
|
||||
val make : TVar.store -> lit:TLit.t -> propagate_reason:Reason.t -> unit -> t
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
open Sidekick_core
|
||||
module Proof = Sidekick_proof
|
||||
module Check_res = Sidekick_abstract_solver.Check_res
|
||||
module Unknown = Sidekick_abstract_solver.Unknown
|
||||
|
||||
(** Argument to the solver *)
|
||||
module type ARG = sig
|
||||
|
|
@ -8,6 +9,16 @@ module type ARG = sig
|
|||
(** build a disjunction *)
|
||||
end
|
||||
|
||||
(* TODO: embed a simplifier, and have simplify hooks in plugins.
|
||||
Then use the simplifier on any asserted term *)
|
||||
|
||||
type pending_assignment = {
|
||||
var: TVar.t;
|
||||
value: Value.t;
|
||||
level: int;
|
||||
reason: Reason.t;
|
||||
}
|
||||
|
||||
type t = {
|
||||
tst: Term.store;
|
||||
vst: TVar.store;
|
||||
|
|
@ -16,12 +27,26 @@ type t = {
|
|||
trail: Trail.t;
|
||||
plugins: plugin Vec.t;
|
||||
term_to_var: Term_to_var.t;
|
||||
pending_assignments: pending_assignment Vec.t;
|
||||
mutable last_res: Check_res.t option;
|
||||
proof_tracer: Proof.Tracer.t;
|
||||
n_conflicts: int Stat.counter;
|
||||
n_propagations: int Stat.counter;
|
||||
n_restarts: int Stat.counter;
|
||||
}
|
||||
|
||||
and plugin_action = t
|
||||
|
||||
(* FIXME:
|
||||
- add [on_add_var: TVar.t -> unit]
|
||||
and [on_remove_var: TVar.t -> unit].
|
||||
these are called when a variable becomes relevant/is removed or GC'd
|
||||
(in particular: setup watches + var constraints on add,
|
||||
kill watches and remove constraints on remove)
|
||||
|
||||
- add [gc_mark : TVar.t -> recurse:(TVar.t -> unit) -> unit]
|
||||
to mark sub-variables during GC mark phase.
|
||||
*)
|
||||
and plugin =
|
||||
| P : {
|
||||
st: 'st;
|
||||
|
|
@ -42,13 +67,17 @@ let create ?(stats = Stat.create ()) ~arg tst vst ~proof_tracer () : t =
|
|||
stats;
|
||||
trail = Trail.create ();
|
||||
plugins = Vec.create ();
|
||||
pending_assignments = Vec.create ();
|
||||
term_to_var = Term_to_var.create vst;
|
||||
last_res = None;
|
||||
proof_tracer;
|
||||
n_restarts = Stat.mk_int stats "cdsat.restarts";
|
||||
n_conflicts = Stat.mk_int stats "cdsat.conflicts";
|
||||
n_propagations = Stat.mk_int stats "cdsat.propagations";
|
||||
}
|
||||
|
||||
let[@inline] trail self = self.trail
|
||||
let[@inline] iter_plugins self f = Vec.iter ~f self.plugins
|
||||
let[@inline] iter_plugins self ~f = Vec.iter ~f self.plugins
|
||||
let[@inline] tst self = self.tst
|
||||
let[@inline] vst self = self.vst
|
||||
let[@inline] last_res self = self.last_res
|
||||
|
|
@ -79,10 +108,30 @@ let push_level (self : t) : unit =
|
|||
()
|
||||
|
||||
let pop_levels (self : t) n : unit =
|
||||
let {
|
||||
tst = _;
|
||||
vst = _;
|
||||
arg = _;
|
||||
stats = _;
|
||||
trail;
|
||||
plugins;
|
||||
term_to_var = _;
|
||||
pending_assignments;
|
||||
last_res = _;
|
||||
proof_tracer = _;
|
||||
n_propagations = _;
|
||||
n_conflicts = _;
|
||||
n_restarts = _;
|
||||
} =
|
||||
self
|
||||
in
|
||||
Log.debugf 50 (fun k -> k "(@[cdsat.core.pop-levels %d@])" n);
|
||||
if n > 0 then self.last_res <- None;
|
||||
Trail.pop_levels self.trail n ~f:(fun v -> TVar.unassign self.vst v);
|
||||
Vec.iter self.plugins ~f:(fun (P p) -> p.pop_levels p.st n);
|
||||
if n > 0 then (
|
||||
self.last_res <- None;
|
||||
Vec.clear pending_assignments
|
||||
);
|
||||
Trail.pop_levels trail n ~f:(fun v -> TVar.unassign self.vst v);
|
||||
Vec.iter plugins ~f:(fun (P p) -> p.pop_levels p.st n);
|
||||
()
|
||||
|
||||
(* term to var *)
|
||||
|
|
@ -105,28 +154,123 @@ let add_plugin self (pb : Plugin.builder) : unit =
|
|||
|
||||
let add_ty (_self : t) ~ty:_ : unit = ()
|
||||
|
||||
(* Assign [v <- value] for [reason] at [level].
|
||||
This assignment is delayed. *)
|
||||
let assign (self : t) (v : TVar.t) ~(value : Value.t) ~level:v_level ~reason :
|
||||
unit =
|
||||
Log.debugf 50 (fun k ->
|
||||
k "(@[cdsat.core.assign@ `%a`@ @[<- %a@]@ :reason %a@])"
|
||||
(TVar.pp self.vst) v Value.pp value Reason.pp reason);
|
||||
self.last_res <- None;
|
||||
match TVar.value self.vst v with
|
||||
| None ->
|
||||
TVar.assign self.vst v ~value ~level:v_level ~reason;
|
||||
Trail.push_assignment self.trail v
|
||||
| Some value' when Value.equal value value' -> () (* idempotent *)
|
||||
| Some value' ->
|
||||
(* TODO: conflict *)
|
||||
Log.debugf 0 (fun k -> k "TODO: conflict (incompatible values)");
|
||||
()
|
||||
Vec.push self.pending_assignments { var = v; value; level = v_level; reason }
|
||||
|
||||
exception E_conflict of Conflict.t
|
||||
|
||||
let raise_conflict (c : Conflict.t) : 'a = raise (E_conflict c)
|
||||
|
||||
(* add pending assignments to the trail. This might trigger a conflict
|
||||
in case an assignment contradicts an already existing assignment. *)
|
||||
let perform_pending_assignments (self : t) : unit =
|
||||
while not (Vec.is_empty self.pending_assignments) do
|
||||
let { var = v; level = v_level; value; reason } =
|
||||
Vec.pop_exn self.pending_assignments
|
||||
in
|
||||
match TVar.value self.vst v with
|
||||
| None ->
|
||||
TVar.assign self.vst v ~value ~level:v_level ~reason;
|
||||
Trail.push_assignment self.trail v
|
||||
| Some value' when Value.equal value value' -> () (* idempotent *)
|
||||
| Some _value' ->
|
||||
(* conflict should only occur on booleans since they're the only
|
||||
propagation-able variables *)
|
||||
assert (Term.is_a_bool (TVar.term self.vst v));
|
||||
Log.debugf 0 (fun k ->
|
||||
k "TODO: conflict (incompatible values for %a)" (TVar.pp self.vst) v);
|
||||
raise_conflict
|
||||
@@ Conflict.make self.vst ~lit:(TLit.make true v) ~propagate_reason:reason
|
||||
()
|
||||
done
|
||||
|
||||
let propagate (self : t) : Conflict.t option =
|
||||
let@ () = Profile.with_ "cdsat.propagate" in
|
||||
try
|
||||
let continue = ref true in
|
||||
while !continue do
|
||||
perform_pending_assignments self;
|
||||
|
||||
while Trail.head self.trail < Trail.size self.trail do
|
||||
let var = Trail.get self.trail (Trail.head self.trail) in
|
||||
|
||||
(* TODO: call plugins *)
|
||||
Log.debugf 0 (fun k -> k "TODO: propagate %a" (TVar.pp self.vst) var);
|
||||
|
||||
let value =
|
||||
match TVar.value self.vst var with
|
||||
| Some v -> v
|
||||
| None -> assert false
|
||||
in
|
||||
|
||||
iter_plugins self ~f:(fun (P p) -> p.propagate p.st self var value);
|
||||
|
||||
(* move to next var *)
|
||||
Trail.set_head self.trail (Trail.head self.trail + 1)
|
||||
done;
|
||||
|
||||
(* did we reach fixpoint? *)
|
||||
if Vec.is_empty self.pending_assignments then continue := false
|
||||
done;
|
||||
None
|
||||
with E_conflict c -> Some c
|
||||
|
||||
let solve ~on_exit ~on_progress ~should_stop ~assumptions (self : t) :
|
||||
Check_res.t =
|
||||
let@ () = Profile.with_ "cdsat.solve" in
|
||||
self.last_res <- None;
|
||||
(* TODO: outer loop (propagate; decide)* *)
|
||||
(* TODO: propagation loop, involving plugins *)
|
||||
assert false
|
||||
|
||||
(* FIXME: handle assumptions.
|
||||
- do assumptions first when deciding (forced decisions)
|
||||
- in case of conflict below assumptions len, special conflict analysis to
|
||||
compute unsat core
|
||||
*)
|
||||
|
||||
(* control if loop stops *)
|
||||
let continue = ref true in
|
||||
let n_conflicts = ref 0 in
|
||||
let res = ref (Check_res.Unknown Unknown.U_incomplete) in
|
||||
|
||||
(* main loop *)
|
||||
while !continue do
|
||||
if !n_conflicts mod 64 = 0 then on_progress ();
|
||||
|
||||
(* propagate *)
|
||||
(match propagate self with
|
||||
| Some c ->
|
||||
Log.debugf 1 (fun k ->
|
||||
k "(@[cdsat.propagate.found-conflict@ %a@])" Conflict.pp c);
|
||||
incr n_conflicts;
|
||||
Stat.incr self.n_conflicts;
|
||||
|
||||
(* TODO: handle conflict, learn a clause or declare unsat *)
|
||||
(* TODO: see if we want to restart *)
|
||||
assert false
|
||||
| None ->
|
||||
Log.debugf 0 (fun k -> k "TODO: decide");
|
||||
(* TODO: decide *)
|
||||
());
|
||||
|
||||
(* regularly check if it's time to stop *)
|
||||
if !n_conflicts mod 64 = 0 then
|
||||
if should_stop !n_conflicts then (
|
||||
Log.debugf 1 (fun k -> k "(@[cdsat.stop@ :caused-by-callback@])");
|
||||
|
||||
res := Check_res.Unknown Unknown.U_asked_to_stop;
|
||||
continue := false
|
||||
)
|
||||
done;
|
||||
|
||||
(* cleanup and exit *)
|
||||
List.iter (fun f -> f ()) on_exit;
|
||||
!res
|
||||
|
||||
(* plugin actions *)
|
||||
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ val tst : t -> Term.store
|
|||
val vst : t -> TVar.store
|
||||
val trail : t -> Trail.t
|
||||
val add_plugin : t -> Plugin.builder -> unit
|
||||
val iter_plugins : t -> Plugin.t Iter.t
|
||||
val iter_plugins : t -> f:(Plugin.t -> unit) -> unit
|
||||
|
||||
val last_res : t -> Check_res.t option
|
||||
(** Last result. Reset on backtracking/assertion *)
|
||||
|
|
|
|||
|
|
@ -31,6 +31,8 @@ let decide (self : t) (v : TVar.t) : Value.t option =
|
|||
|
||||
let propagate (self : t) (act : Core.Plugin_action.t) (v : TVar.t)
|
||||
(value : Value.t) : unit =
|
||||
Log.debugf 0 (fun k ->
|
||||
k "(@[bool-plugin.propagate %a@])" (TVar.pp self.vst) v);
|
||||
()
|
||||
(* TODO: BCP *)
|
||||
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ let propagate (self : t) _act (v : TVar.t) (value : Value.t) =
|
|||
| Unin_const _ -> ()
|
||||
| Unin_fun { f = _; args } ->
|
||||
(* TODO: update congruence table *)
|
||||
Log.debugf 1 (fun k -> k "FIXME: congruence rule");
|
||||
Log.debugf 0 (fun k -> k "FIXME: congruence rule");
|
||||
()
|
||||
| _ -> ()
|
||||
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ let create ?(stats = Stat.create ()) ~(arg : (module ARG)) tst vst ~proof_tracer
|
|||
|
||||
let[@inline] core self = self.core
|
||||
let add_plugin self p = Core.add_plugin self.core p
|
||||
let[@inline] iter_plugins self f = Core.iter_plugins self.core f
|
||||
let[@inline] iter_plugins self f = Core.iter_plugins self.core ~f
|
||||
let[@inline] tst self = self.tst
|
||||
let[@inline] vst self = self.vst
|
||||
|
||||
|
|
|
|||
|
|
@ -1,15 +1,21 @@
|
|||
module VVec = TVar.Vec_of
|
||||
|
||||
type t = { vars: VVec.t; levels: Veci.t }
|
||||
type t = { vars: VVec.t; levels: Veci.t; mutable head: int }
|
||||
|
||||
let create () : t = { vars = VVec.create (); levels = Veci.create () }
|
||||
let create () : t = { vars = VVec.create (); levels = Veci.create (); head = 0 }
|
||||
|
||||
let[@inline] push_assignment (self : t) (v : TVar.t) : unit =
|
||||
VVec.push self.vars v
|
||||
|
||||
let[@inline] var_at (self : t) (i : int) : TVar.t = VVec.get self.vars i
|
||||
let[@inline] get (self : t) (i : int) : TVar.t = VVec.get self.vars i
|
||||
let[@inline] size (self : t) : int = VVec.size self.vars
|
||||
let[@inline] n_levels self = Veci.size self.levels
|
||||
let push_level (self : t) : unit = Veci.push self.levels (VVec.size self.vars)
|
||||
|
||||
let[@inline] push_level (self : t) : unit =
|
||||
Veci.push self.levels (VVec.size self.vars)
|
||||
|
||||
let[@inline] head self = self.head
|
||||
let[@inline] set_head self n = self.head <- n
|
||||
|
||||
let pop_levels (self : t) (n : int) ~f : unit =
|
||||
if n <= n_levels self then (
|
||||
|
|
@ -18,5 +24,7 @@ let pop_levels (self : t) (n : int) ~f : unit =
|
|||
while VVec.size self.vars > idx do
|
||||
let elt = VVec.pop self.vars in
|
||||
f elt
|
||||
done
|
||||
done;
|
||||
(* also reset head *)
|
||||
if self.head > idx then self.head <- idx
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,10 @@
|
|||
type t
|
||||
|
||||
val create : unit -> t
|
||||
val var_at : t -> int -> TVar.t
|
||||
val get : t -> int -> TVar.t
|
||||
val size : t -> int
|
||||
val push_assignment : t -> TVar.t -> unit
|
||||
val head : t -> int
|
||||
val set_head : t -> int -> unit
|
||||
|
||||
include Sidekick_sigs.BACKTRACKABLE0_CB with type t := t and type elt := TVar.t
|
||||
|
|
|
|||
|
|
@ -126,6 +126,8 @@ let is_bool t =
|
|||
| E_const { c_view = C_bool; _ } -> true
|
||||
| _ -> false
|
||||
|
||||
let[@inline] is_a_bool t = is_bool (ty t)
|
||||
|
||||
let is_true t =
|
||||
match view t with
|
||||
| E_const { c_view = C_true; _ } -> true
|
||||
|
|
|
|||
|
|
@ -24,7 +24,14 @@ val ite : store -> t -> t -> t -> t
|
|||
(** [ite a b c] is [if a then b else c] *)
|
||||
|
||||
val is_eq : t -> bool
|
||||
(** [is_eq t] is true if [t] is the [=] constant *)
|
||||
|
||||
val is_bool : t -> bool
|
||||
(** [is_bool t] is true if [t] is the type bool itself *)
|
||||
|
||||
val is_a_bool : t -> bool
|
||||
(** [is_a_bool t] is true if [t] has type [bool] *)
|
||||
|
||||
val is_true : t -> bool
|
||||
val is_false : t -> bool
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue