mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-06 03:05:31 -05:00
wip(cdsat): start implementing propagation
This commit is contained in:
parent
6f1abedb44
commit
6374fd7d5f
11 changed files with 204 additions and 25 deletions
6
src/cdsat/conflict.ml
Normal file
6
src/cdsat/conflict.ml
Normal file
|
|
@ -0,0 +1,6 @@
|
||||||
|
type t = { vst: TVar.store; lit: TLit.t; propagate_reason: Reason.t }
|
||||||
|
|
||||||
|
let pp out (self : t) =
|
||||||
|
Fmt.fprintf out "(@[conflict %a@])" (TLit.pp self.vst) self.lit
|
||||||
|
|
||||||
|
let make vst ~lit ~propagate_reason () : t = { vst; lit; propagate_reason }
|
||||||
7
src/cdsat/conflict.mli
Normal file
7
src/cdsat/conflict.mli
Normal file
|
|
@ -0,0 +1,7 @@
|
||||||
|
(** Conflict discovered during search *)
|
||||||
|
|
||||||
|
type t = { vst: TVar.store; lit: TLit.t; propagate_reason: Reason.t }
|
||||||
|
|
||||||
|
include Sidekick_sigs.PRINT with type t := t
|
||||||
|
|
||||||
|
val make : TVar.store -> lit:TLit.t -> propagate_reason:Reason.t -> unit -> t
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
open Sidekick_core
|
open Sidekick_core
|
||||||
module Proof = Sidekick_proof
|
module Proof = Sidekick_proof
|
||||||
module Check_res = Sidekick_abstract_solver.Check_res
|
module Check_res = Sidekick_abstract_solver.Check_res
|
||||||
|
module Unknown = Sidekick_abstract_solver.Unknown
|
||||||
|
|
||||||
(** Argument to the solver *)
|
(** Argument to the solver *)
|
||||||
module type ARG = sig
|
module type ARG = sig
|
||||||
|
|
@ -8,6 +9,16 @@ module type ARG = sig
|
||||||
(** build a disjunction *)
|
(** build a disjunction *)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
(* TODO: embed a simplifier, and have simplify hooks in plugins.
|
||||||
|
Then use the simplifier on any asserted term *)
|
||||||
|
|
||||||
|
type pending_assignment = {
|
||||||
|
var: TVar.t;
|
||||||
|
value: Value.t;
|
||||||
|
level: int;
|
||||||
|
reason: Reason.t;
|
||||||
|
}
|
||||||
|
|
||||||
type t = {
|
type t = {
|
||||||
tst: Term.store;
|
tst: Term.store;
|
||||||
vst: TVar.store;
|
vst: TVar.store;
|
||||||
|
|
@ -16,12 +27,26 @@ type t = {
|
||||||
trail: Trail.t;
|
trail: Trail.t;
|
||||||
plugins: plugin Vec.t;
|
plugins: plugin Vec.t;
|
||||||
term_to_var: Term_to_var.t;
|
term_to_var: Term_to_var.t;
|
||||||
|
pending_assignments: pending_assignment Vec.t;
|
||||||
mutable last_res: Check_res.t option;
|
mutable last_res: Check_res.t option;
|
||||||
proof_tracer: Proof.Tracer.t;
|
proof_tracer: Proof.Tracer.t;
|
||||||
|
n_conflicts: int Stat.counter;
|
||||||
|
n_propagations: int Stat.counter;
|
||||||
|
n_restarts: int Stat.counter;
|
||||||
}
|
}
|
||||||
|
|
||||||
and plugin_action = t
|
and plugin_action = t
|
||||||
|
|
||||||
|
(* FIXME:
|
||||||
|
- add [on_add_var: TVar.t -> unit]
|
||||||
|
and [on_remove_var: TVar.t -> unit].
|
||||||
|
these are called when a variable becomes relevant/is removed or GC'd
|
||||||
|
(in particular: setup watches + var constraints on add,
|
||||||
|
kill watches and remove constraints on remove)
|
||||||
|
|
||||||
|
- add [gc_mark : TVar.t -> recurse:(TVar.t -> unit) -> unit]
|
||||||
|
to mark sub-variables during GC mark phase.
|
||||||
|
*)
|
||||||
and plugin =
|
and plugin =
|
||||||
| P : {
|
| P : {
|
||||||
st: 'st;
|
st: 'st;
|
||||||
|
|
@ -42,13 +67,17 @@ let create ?(stats = Stat.create ()) ~arg tst vst ~proof_tracer () : t =
|
||||||
stats;
|
stats;
|
||||||
trail = Trail.create ();
|
trail = Trail.create ();
|
||||||
plugins = Vec.create ();
|
plugins = Vec.create ();
|
||||||
|
pending_assignments = Vec.create ();
|
||||||
term_to_var = Term_to_var.create vst;
|
term_to_var = Term_to_var.create vst;
|
||||||
last_res = None;
|
last_res = None;
|
||||||
proof_tracer;
|
proof_tracer;
|
||||||
|
n_restarts = Stat.mk_int stats "cdsat.restarts";
|
||||||
|
n_conflicts = Stat.mk_int stats "cdsat.conflicts";
|
||||||
|
n_propagations = Stat.mk_int stats "cdsat.propagations";
|
||||||
}
|
}
|
||||||
|
|
||||||
let[@inline] trail self = self.trail
|
let[@inline] trail self = self.trail
|
||||||
let[@inline] iter_plugins self f = Vec.iter ~f self.plugins
|
let[@inline] iter_plugins self ~f = Vec.iter ~f self.plugins
|
||||||
let[@inline] tst self = self.tst
|
let[@inline] tst self = self.tst
|
||||||
let[@inline] vst self = self.vst
|
let[@inline] vst self = self.vst
|
||||||
let[@inline] last_res self = self.last_res
|
let[@inline] last_res self = self.last_res
|
||||||
|
|
@ -79,10 +108,30 @@ let push_level (self : t) : unit =
|
||||||
()
|
()
|
||||||
|
|
||||||
let pop_levels (self : t) n : unit =
|
let pop_levels (self : t) n : unit =
|
||||||
|
let {
|
||||||
|
tst = _;
|
||||||
|
vst = _;
|
||||||
|
arg = _;
|
||||||
|
stats = _;
|
||||||
|
trail;
|
||||||
|
plugins;
|
||||||
|
term_to_var = _;
|
||||||
|
pending_assignments;
|
||||||
|
last_res = _;
|
||||||
|
proof_tracer = _;
|
||||||
|
n_propagations = _;
|
||||||
|
n_conflicts = _;
|
||||||
|
n_restarts = _;
|
||||||
|
} =
|
||||||
|
self
|
||||||
|
in
|
||||||
Log.debugf 50 (fun k -> k "(@[cdsat.core.pop-levels %d@])" n);
|
Log.debugf 50 (fun k -> k "(@[cdsat.core.pop-levels %d@])" n);
|
||||||
if n > 0 then self.last_res <- None;
|
if n > 0 then (
|
||||||
Trail.pop_levels self.trail n ~f:(fun v -> TVar.unassign self.vst v);
|
self.last_res <- None;
|
||||||
Vec.iter self.plugins ~f:(fun (P p) -> p.pop_levels p.st n);
|
Vec.clear pending_assignments
|
||||||
|
);
|
||||||
|
Trail.pop_levels trail n ~f:(fun v -> TVar.unassign self.vst v);
|
||||||
|
Vec.iter plugins ~f:(fun (P p) -> p.pop_levels p.st n);
|
||||||
()
|
()
|
||||||
|
|
||||||
(* term to var *)
|
(* term to var *)
|
||||||
|
|
@ -105,28 +154,123 @@ let add_plugin self (pb : Plugin.builder) : unit =
|
||||||
|
|
||||||
let add_ty (_self : t) ~ty:_ : unit = ()
|
let add_ty (_self : t) ~ty:_ : unit = ()
|
||||||
|
|
||||||
|
(* Assign [v <- value] for [reason] at [level].
|
||||||
|
This assignment is delayed. *)
|
||||||
let assign (self : t) (v : TVar.t) ~(value : Value.t) ~level:v_level ~reason :
|
let assign (self : t) (v : TVar.t) ~(value : Value.t) ~level:v_level ~reason :
|
||||||
unit =
|
unit =
|
||||||
Log.debugf 50 (fun k ->
|
Log.debugf 50 (fun k ->
|
||||||
k "(@[cdsat.core.assign@ `%a`@ @[<- %a@]@ :reason %a@])"
|
k "(@[cdsat.core.assign@ `%a`@ @[<- %a@]@ :reason %a@])"
|
||||||
(TVar.pp self.vst) v Value.pp value Reason.pp reason);
|
(TVar.pp self.vst) v Value.pp value Reason.pp reason);
|
||||||
self.last_res <- None;
|
self.last_res <- None;
|
||||||
match TVar.value self.vst v with
|
Vec.push self.pending_assignments { var = v; value; level = v_level; reason }
|
||||||
| None ->
|
|
||||||
TVar.assign self.vst v ~value ~level:v_level ~reason;
|
exception E_conflict of Conflict.t
|
||||||
Trail.push_assignment self.trail v
|
|
||||||
| Some value' when Value.equal value value' -> () (* idempotent *)
|
let raise_conflict (c : Conflict.t) : 'a = raise (E_conflict c)
|
||||||
| Some value' ->
|
|
||||||
(* TODO: conflict *)
|
(* add pending assignments to the trail. This might trigger a conflict
|
||||||
Log.debugf 0 (fun k -> k "TODO: conflict (incompatible values)");
|
in case an assignment contradicts an already existing assignment. *)
|
||||||
()
|
let perform_pending_assignments (self : t) : unit =
|
||||||
|
while not (Vec.is_empty self.pending_assignments) do
|
||||||
|
let { var = v; level = v_level; value; reason } =
|
||||||
|
Vec.pop_exn self.pending_assignments
|
||||||
|
in
|
||||||
|
match TVar.value self.vst v with
|
||||||
|
| None ->
|
||||||
|
TVar.assign self.vst v ~value ~level:v_level ~reason;
|
||||||
|
Trail.push_assignment self.trail v
|
||||||
|
| Some value' when Value.equal value value' -> () (* idempotent *)
|
||||||
|
| Some _value' ->
|
||||||
|
(* conflict should only occur on booleans since they're the only
|
||||||
|
propagation-able variables *)
|
||||||
|
assert (Term.is_a_bool (TVar.term self.vst v));
|
||||||
|
Log.debugf 0 (fun k ->
|
||||||
|
k "TODO: conflict (incompatible values for %a)" (TVar.pp self.vst) v);
|
||||||
|
raise_conflict
|
||||||
|
@@ Conflict.make self.vst ~lit:(TLit.make true v) ~propagate_reason:reason
|
||||||
|
()
|
||||||
|
done
|
||||||
|
|
||||||
|
let propagate (self : t) : Conflict.t option =
|
||||||
|
let@ () = Profile.with_ "cdsat.propagate" in
|
||||||
|
try
|
||||||
|
let continue = ref true in
|
||||||
|
while !continue do
|
||||||
|
perform_pending_assignments self;
|
||||||
|
|
||||||
|
while Trail.head self.trail < Trail.size self.trail do
|
||||||
|
let var = Trail.get self.trail (Trail.head self.trail) in
|
||||||
|
|
||||||
|
(* TODO: call plugins *)
|
||||||
|
Log.debugf 0 (fun k -> k "TODO: propagate %a" (TVar.pp self.vst) var);
|
||||||
|
|
||||||
|
let value =
|
||||||
|
match TVar.value self.vst var with
|
||||||
|
| Some v -> v
|
||||||
|
| None -> assert false
|
||||||
|
in
|
||||||
|
|
||||||
|
iter_plugins self ~f:(fun (P p) -> p.propagate p.st self var value);
|
||||||
|
|
||||||
|
(* move to next var *)
|
||||||
|
Trail.set_head self.trail (Trail.head self.trail + 1)
|
||||||
|
done;
|
||||||
|
|
||||||
|
(* did we reach fixpoint? *)
|
||||||
|
if Vec.is_empty self.pending_assignments then continue := false
|
||||||
|
done;
|
||||||
|
None
|
||||||
|
with E_conflict c -> Some c
|
||||||
|
|
||||||
let solve ~on_exit ~on_progress ~should_stop ~assumptions (self : t) :
|
let solve ~on_exit ~on_progress ~should_stop ~assumptions (self : t) :
|
||||||
Check_res.t =
|
Check_res.t =
|
||||||
|
let@ () = Profile.with_ "cdsat.solve" in
|
||||||
self.last_res <- None;
|
self.last_res <- None;
|
||||||
(* TODO: outer loop (propagate; decide)* *)
|
|
||||||
(* TODO: propagation loop, involving plugins *)
|
(* FIXME: handle assumptions.
|
||||||
assert false
|
- do assumptions first when deciding (forced decisions)
|
||||||
|
- in case of conflict below assumptions len, special conflict analysis to
|
||||||
|
compute unsat core
|
||||||
|
*)
|
||||||
|
|
||||||
|
(* control if loop stops *)
|
||||||
|
let continue = ref true in
|
||||||
|
let n_conflicts = ref 0 in
|
||||||
|
let res = ref (Check_res.Unknown Unknown.U_incomplete) in
|
||||||
|
|
||||||
|
(* main loop *)
|
||||||
|
while !continue do
|
||||||
|
if !n_conflicts mod 64 = 0 then on_progress ();
|
||||||
|
|
||||||
|
(* propagate *)
|
||||||
|
(match propagate self with
|
||||||
|
| Some c ->
|
||||||
|
Log.debugf 1 (fun k ->
|
||||||
|
k "(@[cdsat.propagate.found-conflict@ %a@])" Conflict.pp c);
|
||||||
|
incr n_conflicts;
|
||||||
|
Stat.incr self.n_conflicts;
|
||||||
|
|
||||||
|
(* TODO: handle conflict, learn a clause or declare unsat *)
|
||||||
|
(* TODO: see if we want to restart *)
|
||||||
|
assert false
|
||||||
|
| None ->
|
||||||
|
Log.debugf 0 (fun k -> k "TODO: decide");
|
||||||
|
(* TODO: decide *)
|
||||||
|
());
|
||||||
|
|
||||||
|
(* regularly check if it's time to stop *)
|
||||||
|
if !n_conflicts mod 64 = 0 then
|
||||||
|
if should_stop !n_conflicts then (
|
||||||
|
Log.debugf 1 (fun k -> k "(@[cdsat.stop@ :caused-by-callback@])");
|
||||||
|
|
||||||
|
res := Check_res.Unknown Unknown.U_asked_to_stop;
|
||||||
|
continue := false
|
||||||
|
)
|
||||||
|
done;
|
||||||
|
|
||||||
|
(* cleanup and exit *)
|
||||||
|
List.iter (fun f -> f ()) on_exit;
|
||||||
|
!res
|
||||||
|
|
||||||
(* plugin actions *)
|
(* plugin actions *)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -55,7 +55,7 @@ val tst : t -> Term.store
|
||||||
val vst : t -> TVar.store
|
val vst : t -> TVar.store
|
||||||
val trail : t -> Trail.t
|
val trail : t -> Trail.t
|
||||||
val add_plugin : t -> Plugin.builder -> unit
|
val add_plugin : t -> Plugin.builder -> unit
|
||||||
val iter_plugins : t -> Plugin.t Iter.t
|
val iter_plugins : t -> f:(Plugin.t -> unit) -> unit
|
||||||
|
|
||||||
val last_res : t -> Check_res.t option
|
val last_res : t -> Check_res.t option
|
||||||
(** Last result. Reset on backtracking/assertion *)
|
(** Last result. Reset on backtracking/assertion *)
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,8 @@ let decide (self : t) (v : TVar.t) : Value.t option =
|
||||||
|
|
||||||
let propagate (self : t) (act : Core.Plugin_action.t) (v : TVar.t)
|
let propagate (self : t) (act : Core.Plugin_action.t) (v : TVar.t)
|
||||||
(value : Value.t) : unit =
|
(value : Value.t) : unit =
|
||||||
|
Log.debugf 0 (fun k ->
|
||||||
|
k "(@[bool-plugin.propagate %a@])" (TVar.pp self.vst) v);
|
||||||
()
|
()
|
||||||
(* TODO: BCP *)
|
(* TODO: BCP *)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -40,7 +40,7 @@ let propagate (self : t) _act (v : TVar.t) (value : Value.t) =
|
||||||
| Unin_const _ -> ()
|
| Unin_const _ -> ()
|
||||||
| Unin_fun { f = _; args } ->
|
| Unin_fun { f = _; args } ->
|
||||||
(* TODO: update congruence table *)
|
(* TODO: update congruence table *)
|
||||||
Log.debugf 1 (fun k -> k "FIXME: congruence rule");
|
Log.debugf 0 (fun k -> k "FIXME: congruence rule");
|
||||||
()
|
()
|
||||||
| _ -> ()
|
| _ -> ()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ let create ?(stats = Stat.create ()) ~(arg : (module ARG)) tst vst ~proof_tracer
|
||||||
|
|
||||||
let[@inline] core self = self.core
|
let[@inline] core self = self.core
|
||||||
let add_plugin self p = Core.add_plugin self.core p
|
let add_plugin self p = Core.add_plugin self.core p
|
||||||
let[@inline] iter_plugins self f = Core.iter_plugins self.core f
|
let[@inline] iter_plugins self f = Core.iter_plugins self.core ~f
|
||||||
let[@inline] tst self = self.tst
|
let[@inline] tst self = self.tst
|
||||||
let[@inline] vst self = self.vst
|
let[@inline] vst self = self.vst
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,21 @@
|
||||||
module VVec = TVar.Vec_of
|
module VVec = TVar.Vec_of
|
||||||
|
|
||||||
type t = { vars: VVec.t; levels: Veci.t }
|
type t = { vars: VVec.t; levels: Veci.t; mutable head: int }
|
||||||
|
|
||||||
let create () : t = { vars = VVec.create (); levels = Veci.create () }
|
let create () : t = { vars = VVec.create (); levels = Veci.create (); head = 0 }
|
||||||
|
|
||||||
let[@inline] push_assignment (self : t) (v : TVar.t) : unit =
|
let[@inline] push_assignment (self : t) (v : TVar.t) : unit =
|
||||||
VVec.push self.vars v
|
VVec.push self.vars v
|
||||||
|
|
||||||
let[@inline] var_at (self : t) (i : int) : TVar.t = VVec.get self.vars i
|
let[@inline] get (self : t) (i : int) : TVar.t = VVec.get self.vars i
|
||||||
|
let[@inline] size (self : t) : int = VVec.size self.vars
|
||||||
let[@inline] n_levels self = Veci.size self.levels
|
let[@inline] n_levels self = Veci.size self.levels
|
||||||
let push_level (self : t) : unit = Veci.push self.levels (VVec.size self.vars)
|
|
||||||
|
let[@inline] push_level (self : t) : unit =
|
||||||
|
Veci.push self.levels (VVec.size self.vars)
|
||||||
|
|
||||||
|
let[@inline] head self = self.head
|
||||||
|
let[@inline] set_head self n = self.head <- n
|
||||||
|
|
||||||
let pop_levels (self : t) (n : int) ~f : unit =
|
let pop_levels (self : t) (n : int) ~f : unit =
|
||||||
if n <= n_levels self then (
|
if n <= n_levels self then (
|
||||||
|
|
@ -18,5 +24,7 @@ let pop_levels (self : t) (n : int) ~f : unit =
|
||||||
while VVec.size self.vars > idx do
|
while VVec.size self.vars > idx do
|
||||||
let elt = VVec.pop self.vars in
|
let elt = VVec.pop self.vars in
|
||||||
f elt
|
f elt
|
||||||
done
|
done;
|
||||||
|
(* also reset head *)
|
||||||
|
if self.head > idx then self.head <- idx
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,10 @@
|
||||||
type t
|
type t
|
||||||
|
|
||||||
val create : unit -> t
|
val create : unit -> t
|
||||||
val var_at : t -> int -> TVar.t
|
val get : t -> int -> TVar.t
|
||||||
|
val size : t -> int
|
||||||
val push_assignment : t -> TVar.t -> unit
|
val push_assignment : t -> TVar.t -> unit
|
||||||
|
val head : t -> int
|
||||||
|
val set_head : t -> int -> unit
|
||||||
|
|
||||||
include Sidekick_sigs.BACKTRACKABLE0_CB with type t := t and type elt := TVar.t
|
include Sidekick_sigs.BACKTRACKABLE0_CB with type t := t and type elt := TVar.t
|
||||||
|
|
|
||||||
|
|
@ -126,6 +126,8 @@ let is_bool t =
|
||||||
| E_const { c_view = C_bool; _ } -> true
|
| E_const { c_view = C_bool; _ } -> true
|
||||||
| _ -> false
|
| _ -> false
|
||||||
|
|
||||||
|
let[@inline] is_a_bool t = is_bool (ty t)
|
||||||
|
|
||||||
let is_true t =
|
let is_true t =
|
||||||
match view t with
|
match view t with
|
||||||
| E_const { c_view = C_true; _ } -> true
|
| E_const { c_view = C_true; _ } -> true
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,14 @@ val ite : store -> t -> t -> t -> t
|
||||||
(** [ite a b c] is [if a then b else c] *)
|
(** [ite a b c] is [if a then b else c] *)
|
||||||
|
|
||||||
val is_eq : t -> bool
|
val is_eq : t -> bool
|
||||||
|
(** [is_eq t] is true if [t] is the [=] constant *)
|
||||||
|
|
||||||
val is_bool : t -> bool
|
val is_bool : t -> bool
|
||||||
|
(** [is_bool t] is true if [t] is the type bool itself *)
|
||||||
|
|
||||||
|
val is_a_bool : t -> bool
|
||||||
|
(** [is_a_bool t] is true if [t] has type [bool] *)
|
||||||
|
|
||||||
val is_true : t -> bool
|
val is_true : t -> bool
|
||||||
val is_false : t -> bool
|
val is_false : t -> bool
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue