mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-06 03:05:31 -05:00
Mcsat now works (for pure equality problems)
This commit is contained in:
parent
4f5bb640ca
commit
9cf13bd7a2
15 changed files with 655 additions and 46 deletions
2
Makefile
2
Makefile
|
|
@ -39,7 +39,7 @@ test: bin test_bin
|
|||
@echo "run API tests…"
|
||||
@./test_api.native
|
||||
@echo "run benchmarks…"
|
||||
@/usr/bin/time -f "%e" ./tests/run smt
|
||||
# @/usr/bin/time -f "%e" ./tests/run smt
|
||||
@/usr/bin/time -f "%e" ./tests/run mcsat
|
||||
|
||||
enable_log:
|
||||
|
|
|
|||
|
|
@ -11,39 +11,44 @@ type negated = Formula_intf.negated =
|
|||
module type S = sig
|
||||
(** Signature of formulas that parametrises the Mcsat Solver Module. *)
|
||||
|
||||
module Term : sig
|
||||
(** The type of terms *)
|
||||
type t
|
||||
val hash : t -> int
|
||||
val equal : t -> t -> bool
|
||||
val print : Format.formatter -> t -> unit
|
||||
end
|
||||
|
||||
module Formula : sig
|
||||
(** The type of atomic formulas over terms. *)
|
||||
type t
|
||||
val hash : t -> int
|
||||
val equal : t -> t -> bool
|
||||
val print : Format.formatter -> t -> unit
|
||||
end
|
||||
|
||||
type proof
|
||||
(** An abstract type for proofs *)
|
||||
|
||||
val dummy : Formula.t
|
||||
module Term : sig
|
||||
|
||||
type t
|
||||
(** The type of terms *)
|
||||
|
||||
val hash : t -> int
|
||||
val equal : t -> t -> bool
|
||||
val print : Format.formatter -> t -> unit
|
||||
(** Common functions *)
|
||||
|
||||
end
|
||||
|
||||
module Formula : sig
|
||||
|
||||
type t
|
||||
(** The type of atomic formulas over terms. *)
|
||||
|
||||
val hash : t -> int
|
||||
val equal : t -> t -> bool
|
||||
val print : Format.formatter -> t -> unit
|
||||
(** Common functions *)
|
||||
|
||||
val dummy : t
|
||||
(** Formula constants. A valid formula should never be physically equal to [dummy] *)
|
||||
|
||||
val fresh : unit -> Formula.t
|
||||
(** Returns a fresh litteral, distinct from any other literal (used in cnf conversion) *)
|
||||
|
||||
val neg : Formula.t -> Formula.t
|
||||
val neg : t -> t
|
||||
(** Formula negation *)
|
||||
|
||||
val norm : Formula.t -> Formula.t * negated
|
||||
val norm : t -> t * negated
|
||||
(** Returns a 'normalized' form of the formula, possibly negated
|
||||
(in which case return [Negated]).
|
||||
[norm] must be so that [a] and [neg a] normalise to the same formula,
|
||||
but one returns [Negated] and the other [Same_sign]. *)
|
||||
end
|
||||
|
||||
|
||||
end
|
||||
|
||||
|
|
|
|||
|
|
@ -17,12 +17,14 @@ module type S = sig
|
|||
type proof
|
||||
(** An abstract type for proofs *)
|
||||
|
||||
val hash : t -> int
|
||||
val equal : t -> t -> bool
|
||||
val print : Format.formatter -> t -> unit
|
||||
(** Common functions *)
|
||||
|
||||
val dummy : t
|
||||
(** Formula constants. A valid formula should never be physically equal to [dummy] *)
|
||||
|
||||
val fresh : unit -> t
|
||||
(** Returns a fresh literal, distinct from any other literal (used in cnf conversion) *)
|
||||
|
||||
val neg : t -> t
|
||||
(** Formula negation *)
|
||||
|
||||
|
|
@ -32,11 +34,5 @@ module type S = sig
|
|||
[norm] must be so that [a] and [neg a] normalise to the same formula,
|
||||
but one returns [Same_sign] and one returns [Negated] *)
|
||||
|
||||
val hash : t -> int
|
||||
val equal : t -> t -> bool
|
||||
(** Usual hash and comparison functions. Given to Hashtbl functors. *)
|
||||
|
||||
val print : Format.formatter -> t -> unit
|
||||
(** Printing function used for debugging. *)
|
||||
end
|
||||
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ module McMake (E : Expr_intf.S)(Dummy : sig end) = struct
|
|||
| E_var of var
|
||||
|
||||
(* Dummy values *)
|
||||
let dummy_lit = E.dummy
|
||||
let dummy_lit = E.Formula.dummy
|
||||
|
||||
let rec dummy_var =
|
||||
{ vid = -101;
|
||||
|
|
@ -144,7 +144,7 @@ module McMake (E : Expr_intf.S)(Dummy : sig end) = struct
|
|||
|
||||
let make_boolean_var : formula -> var * Expr_intf.negated =
|
||||
fun t ->
|
||||
let lit, negated = E.norm t in
|
||||
let lit, negated = E.Formula.norm t in
|
||||
try
|
||||
MF.find f_map lit, negated
|
||||
with Not_found ->
|
||||
|
|
@ -168,7 +168,7 @@ module McMake (E : Expr_intf.S)(Dummy : sig end) = struct
|
|||
aid = cpt_fois_2 (* aid = vid*2 *) }
|
||||
and na =
|
||||
{ var = var;
|
||||
lit = E.neg lit;
|
||||
lit = E.Formula.neg lit;
|
||||
watched = Vec.make 10 dummy_clause;
|
||||
neg = pa;
|
||||
is_true = false;
|
||||
|
|
|
|||
231
src/mcsat/eclosure.ml
Normal file
231
src/mcsat/eclosure.ml
Normal file
|
|
@ -0,0 +1,231 @@
|
|||
|
||||
module type Key = sig
|
||||
type t
|
||||
val hash : t -> int
|
||||
val equal : t -> t -> bool
|
||||
val compare : t -> t -> int
|
||||
val print : Format.formatter -> t -> unit
|
||||
end
|
||||
|
||||
module type S = sig
|
||||
type t
|
||||
type var
|
||||
|
||||
exception Unsat of var * var * var list
|
||||
|
||||
val create : Backtrack.Stack.t -> t
|
||||
|
||||
val find : t -> var -> var
|
||||
|
||||
val add_eq : t -> var -> var -> unit
|
||||
val add_neq : t -> var -> var -> unit
|
||||
val add_tag : t -> var -> var -> unit
|
||||
|
||||
val find_tag : t -> var -> var * (var * var) option
|
||||
|
||||
end
|
||||
|
||||
module Make(T : Key) = struct
|
||||
|
||||
module M = Map.Make(T)
|
||||
module H = Backtrack.Hashtbl(T)
|
||||
|
||||
type var = T.t
|
||||
|
||||
exception Equal of var * var
|
||||
exception Same_tag of var * var
|
||||
exception Unsat of var * var * var list
|
||||
|
||||
type repr_info = {
|
||||
rank : int;
|
||||
tag : (T.t * T.t) option;
|
||||
forbidden : (var * var) M.t;
|
||||
}
|
||||
|
||||
type node =
|
||||
| Follow of var
|
||||
| Repr of repr_info
|
||||
|
||||
type t = {
|
||||
size : int H.t;
|
||||
expl : var H.t;
|
||||
repr : node H.t;
|
||||
}
|
||||
|
||||
let create s = {
|
||||
size = H.create s;
|
||||
expl = H.create s;
|
||||
repr = H.create s;
|
||||
}
|
||||
|
||||
(* Union-find algorithm with path compression *)
|
||||
let self_repr = Repr { rank = 0; tag = None; forbidden = M.empty }
|
||||
|
||||
let find_hash m i default =
|
||||
try H.find m i
|
||||
with Not_found -> default
|
||||
|
||||
let rec find_aux m i =
|
||||
match find_hash m i self_repr with
|
||||
| Repr r -> r, i
|
||||
| Follow j ->
|
||||
let r, k = find_aux m j in
|
||||
H.add m i (Follow k);
|
||||
r, k
|
||||
|
||||
let get_repr h x =
|
||||
let r, y = find_aux h.repr x in
|
||||
y, r
|
||||
|
||||
let tag h x v =
|
||||
let r, y = find_aux h.repr x in
|
||||
let new_m =
|
||||
{ r with
|
||||
tag = match r.tag with
|
||||
| Some (_, v') when not (T.equal v v') -> raise (Equal (x, y))
|
||||
| (Some _) as t -> t
|
||||
| None -> Some (x, v) }
|
||||
in
|
||||
H.add h.repr y (Repr new_m)
|
||||
|
||||
let find h x = fst (get_repr h x)
|
||||
|
||||
let find_tag h x =
|
||||
let r, y = find_aux h.repr x in
|
||||
y, r.tag
|
||||
|
||||
let forbid_aux m x =
|
||||
try
|
||||
let a, b = M.find x m in
|
||||
raise (Equal (a, b))
|
||||
with Not_found -> ()
|
||||
|
||||
let link h x mx y my =
|
||||
let new_m = {
|
||||
rank = if mx.rank = my.rank then mx.rank + 1 else mx.rank;
|
||||
tag = (match mx.tag, my.tag with
|
||||
| Some (z, t1), Some (w, t2) ->
|
||||
if not (T.equal t1 t2) then begin
|
||||
Log.debugf 3 "Tag shenanigan : %a (%a) <> %a (%a)" (fun k ->
|
||||
k T.print t1 T.print z T.print t2 T.print w);
|
||||
raise (Equal (z, w))
|
||||
end else Some (z, t1)
|
||||
| Some t, None | None, Some t -> Some t
|
||||
| None, None -> None);
|
||||
forbidden = M.merge (fun _ b1 b2 -> match b1, b2 with
|
||||
| Some r, _ | None, Some r -> Some r | _ -> assert false)
|
||||
mx.forbidden my.forbidden;}
|
||||
in
|
||||
let aux m z eq =
|
||||
match H.find m z with
|
||||
| Repr r ->
|
||||
let r' = { r with
|
||||
forbidden = M.add x eq (M.remove y r.forbidden) }
|
||||
in
|
||||
H.add m z (Repr r')
|
||||
| _ -> assert false
|
||||
in
|
||||
M.iter (aux h.repr) my.forbidden;
|
||||
H.add h.repr y (Follow x);
|
||||
H.add h.repr x (Repr new_m)
|
||||
|
||||
let union h x y =
|
||||
let rx, mx = get_repr h x in
|
||||
let ry, my = get_repr h y in
|
||||
if T.compare rx ry <> 0 then begin
|
||||
forbid_aux mx.forbidden ry;
|
||||
forbid_aux my.forbidden rx;
|
||||
if mx.rank > my.rank then begin
|
||||
link h rx mx ry my
|
||||
end else begin
|
||||
link h ry my rx mx
|
||||
end
|
||||
end
|
||||
|
||||
let forbid h x y =
|
||||
let rx, mx = get_repr h x in
|
||||
let ry, my = get_repr h y in
|
||||
if T.compare rx ry = 0 then
|
||||
raise (Equal (x, y))
|
||||
else match mx.tag, my.tag with
|
||||
| Some (a, v), Some (b, v') when T.compare v v' = 0 ->
|
||||
raise (Same_tag(a, b))
|
||||
| _ ->
|
||||
H.add h.repr ry (Repr { my with forbidden = M.add rx (x, y) my.forbidden });
|
||||
H.add h.repr rx (Repr { mx with forbidden = M.add ry (x, y) mx.forbidden })
|
||||
|
||||
(* Equivalence closure with explanation output *)
|
||||
let find_parent v m = find_hash m v v
|
||||
|
||||
let rec root m acc curr =
|
||||
let parent = find_parent curr m in
|
||||
if T.compare curr parent = 0 then
|
||||
curr :: acc
|
||||
else
|
||||
root m (curr :: acc) parent
|
||||
|
||||
let rec rev_root m curr =
|
||||
let next = find_parent curr m in
|
||||
if T.compare curr next = 0 then
|
||||
curr
|
||||
else begin
|
||||
H.remove m curr;
|
||||
let res = rev_root m next in
|
||||
H.add m next curr;
|
||||
res
|
||||
end
|
||||
|
||||
let expl t a b =
|
||||
let rec aux last = function
|
||||
| x :: r, y :: r' when T.compare x y = 0 ->
|
||||
aux (Some x) (r, r')
|
||||
| l, l' -> begin match last with
|
||||
| Some z -> List.rev_append (z :: l) l'
|
||||
| None -> List.rev_append l l'
|
||||
end
|
||||
in
|
||||
aux None (root t.expl [] a, root t.expl [] b)
|
||||
|
||||
let add_eq_aux t i j =
|
||||
if T.compare (find t i) (find t j) = 0 then
|
||||
()
|
||||
else begin
|
||||
let old_root_i = rev_root t.expl i in
|
||||
let old_root_j = rev_root t.expl j in
|
||||
let nb_i = find_hash t.size old_root_i 0 in
|
||||
let nb_j = find_hash t.size old_root_j 0 in
|
||||
if nb_i < nb_j then begin
|
||||
H.add t.expl i j;
|
||||
H.add t.size j (nb_i + nb_j + 1)
|
||||
end else begin
|
||||
H.add t.expl j i;
|
||||
H.add t.size i (nb_i + nb_j + 1)
|
||||
end
|
||||
end
|
||||
|
||||
(* Functions wrapped to produce explanation in case
|
||||
* something went wrong *)
|
||||
let add_tag t x v =
|
||||
match tag t x v with
|
||||
| () -> ()
|
||||
| exception Equal (a, b) ->
|
||||
raise (Unsat (a, b, expl t a b))
|
||||
|
||||
let add_eq t i j =
|
||||
add_eq_aux t i j;
|
||||
match union t i j with
|
||||
| () -> ()
|
||||
| exception Equal (a, b) ->
|
||||
raise (Unsat (a, b, expl t a b))
|
||||
|
||||
let add_neq t i j =
|
||||
match forbid t i j with
|
||||
| () -> ()
|
||||
| exception Equal (a, b) ->
|
||||
raise (Unsat (a, b, expl t a b))
|
||||
| exception Same_tag (x, y) ->
|
||||
add_eq_aux t i j;
|
||||
let res = expl t i j in
|
||||
raise (Unsat (i, j, res))
|
||||
|
||||
end
|
||||
60
src/mcsat/eclosure.mli
Normal file
60
src/mcsat/eclosure.mli
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
|
||||
(** Equality closure using an union-find structure.
|
||||
This module implements a equality closure algorithm using an union-find structure.
|
||||
It supports adding of equality as well as inequalities, and raises exceptions
|
||||
when trying to build an incoherent closure.
|
||||
Please note that this does not implement congruence closure, as we do not
|
||||
look inside the terms we are given. *)
|
||||
|
||||
module type Key = sig
|
||||
(** The type of keys used by the equality closure algorithm *)
|
||||
|
||||
type t
|
||||
val hash : t -> int
|
||||
val equal : t -> t -> bool
|
||||
val compare : t -> t -> int
|
||||
val print : Format.formatter -> t -> unit
|
||||
end
|
||||
|
||||
module type S = sig
|
||||
(** Type signature for the equality closure algorithm *)
|
||||
|
||||
type t
|
||||
(** Mutable state of the equality closure algorithm. *)
|
||||
|
||||
type var
|
||||
(** The type of expressions on which equality closure is built *)
|
||||
|
||||
exception Unsat of var * var * var list
|
||||
(** Raise when trying to build an incoherent equality closure, with an explanation
|
||||
of the incoherence.
|
||||
[Unsat (a, b, l)] is such that:
|
||||
- [a <> b] has been previously added to the closure.
|
||||
- [l] start with [a] and ends with [b]
|
||||
- for each consecutive terms [p] and [q] in [l],
|
||||
an equality [p = q] has been added to the closure.
|
||||
*)
|
||||
|
||||
val create : Backtrack.Stack.t -> t
|
||||
(** Creates an empty state which uses the given backtrack stack *)
|
||||
|
||||
val find : t -> var -> var
|
||||
(** Returns the representative of the given expression in the current closure state *)
|
||||
|
||||
val add_eq : t -> var -> var -> unit
|
||||
val add_neq : t -> var -> var -> unit
|
||||
(** Add an equality of inequality to the closure. *)
|
||||
|
||||
val add_tag : t -> var -> var -> unit
|
||||
(** Add a tag to an expression. The algorithm ensures that each equality class
|
||||
only has one tag. If incoherent tags are added, an exception is raised. *)
|
||||
|
||||
val find_tag : t -> var -> var * (var * var) option
|
||||
(** Returns the tag associated with the equality class of the given term, if any.
|
||||
More specifically, [find_tag e] returns a pair [(repr, o)] where [repr] is the representant of the equality
|
||||
class of [e]. If the class has a tag, then [o = Some (e', t)] such that [e'] has been tagged with [t] previously. *)
|
||||
|
||||
end
|
||||
|
||||
module Make(T : Key) : S with type var = T.t
|
||||
|
||||
|
|
@ -4,8 +4,10 @@ Copyright 2014 Guillaume Bury
|
|||
Copyright 2014 Simon Cruanes
|
||||
*)
|
||||
|
||||
module Th = Solver.DummyTheory(Expr_smt.Atom)
|
||||
|
||||
module Make(Dummy:sig end) =
|
||||
Solver.Make(Expr_smt.Atom)(Th)(struct end)
|
||||
Mcsolver.Make(struct
|
||||
type proof = unit
|
||||
module Term = Expr_smt.Term
|
||||
module Formula = Expr_smt.Atom
|
||||
end)(Theory_mcsat)(struct end)
|
||||
|
||||
|
|
|
|||
118
src/mcsat/theory_mcsat.ml
Normal file
118
src/mcsat/theory_mcsat.ml
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
|
||||
(* Module initialization *)
|
||||
|
||||
module E = Eclosure.Make(Expr_smt.Term)
|
||||
module H = Backtrack.Hashtbl(Expr_smt.Term)
|
||||
|
||||
(* Type definitions *)
|
||||
|
||||
type proof = unit
|
||||
type term = Expr_smt.Term.t
|
||||
type formula = Expr_smt.Atom.t
|
||||
|
||||
type level = Backtrack.Stack.level
|
||||
|
||||
|
||||
(* Backtracking *)
|
||||
|
||||
let stack = Backtrack.Stack.create ()
|
||||
|
||||
let dummy = Backtrack.Stack.dummy_level
|
||||
|
||||
let current_level () = Backtrack.Stack.level stack
|
||||
|
||||
let backtrack = Backtrack.Stack.backtrack stack
|
||||
|
||||
(* Equality closure *)
|
||||
|
||||
let uf = E.create stack
|
||||
|
||||
let assign t =
|
||||
match E.find_tag uf t with
|
||||
| _, None -> t
|
||||
| _, Some (_, v) -> v
|
||||
|
||||
(* Uninterpreted functions and predicates *)
|
||||
|
||||
let map = H.create stack
|
||||
|
||||
let true_ = Expr_smt.(Term.of_id (Id.ty "true" Ty.prop))
|
||||
let false_ = Expr_smt.(Term.of_id (Id.ty "false" Ty.prop))
|
||||
|
||||
let add_assign t v lvl =
|
||||
H.add map t (v, lvl)
|
||||
|
||||
(* Assignemnts *)
|
||||
|
||||
let rec iter_aux f = function
|
||||
| { Expr_smt.term = Expr_smt.Var _ } as t ->
|
||||
f t
|
||||
| { Expr_smt.term = Expr_smt.App (_, _, l) } as t ->
|
||||
List.iter (iter_aux f) l;
|
||||
f t
|
||||
|
||||
let iter_assignable f = function
|
||||
| { Expr_smt.atom = Expr_smt.Pred p } ->
|
||||
iter_aux f p;
|
||||
| { Expr_smt.atom = Expr_smt.Equal (a, b) } ->
|
||||
iter_aux f a; iter_aux f b
|
||||
|
||||
let eval = function
|
||||
| { Expr_smt.atom = Expr_smt.Pred t } ->
|
||||
begin try
|
||||
let v, lvl = H.find map t in
|
||||
if Expr_smt.Term.equal v true_ then
|
||||
Plugin_intf.Valued (true, lvl)
|
||||
else if Expr_smt.Term.equal v false_ then
|
||||
Plugin_intf.Valued (false, lvl)
|
||||
else
|
||||
Plugin_intf.Unknown
|
||||
with Not_found ->
|
||||
Plugin_intf.Unknown
|
||||
end
|
||||
| { Expr_smt.atom = Expr_smt.Equal (a, b); sign } ->
|
||||
begin try
|
||||
let v_a, a_lvl = H.find map a in
|
||||
let v_b, b_lvl = H.find map a in
|
||||
if Expr_smt.Term.equal v_a v_b then
|
||||
Plugin_intf.Valued(sign, max a_lvl b_lvl)
|
||||
else
|
||||
Plugin_intf.Valued(not sign, max a_lvl b_lvl)
|
||||
with Not_found ->
|
||||
Plugin_intf.Unknown
|
||||
end
|
||||
|
||||
|
||||
(* Theory propagation *)
|
||||
|
||||
let if_sat _ = ()
|
||||
|
||||
let rec chain_eq = function
|
||||
| [] | [_] -> []
|
||||
| a :: ((b :: r) as l) -> (Expr_smt.Atom.eq a b) :: chain_eq l
|
||||
|
||||
let assume s =
|
||||
let open Plugin_intf in
|
||||
try
|
||||
for i = s.start to s.start + s.length - 1 do
|
||||
match s.get i with
|
||||
| Assign (t, v, lvl) ->
|
||||
add_assign t v lvl;
|
||||
E.add_tag uf t v
|
||||
| Lit f ->
|
||||
begin match f with
|
||||
| { Expr_smt.atom = Expr_smt.Equal (u, v); sign = true } ->
|
||||
E.add_eq uf u v
|
||||
| { Expr_smt.atom = Expr_smt.Equal (u, v); sign = false } ->
|
||||
E.add_neq uf u v
|
||||
| { Expr_smt.atom = Expr_smt.Pred p; sign } ->
|
||||
let v = if sign then true_ else false_ in
|
||||
add_assign p v ~-1
|
||||
end
|
||||
done;
|
||||
Plugin_intf.Sat
|
||||
with
|
||||
| E.Unsat (a, b, l) ->
|
||||
let c = Expr_smt.Atom.eq a b :: List.map Expr_smt.Atom.neg (chain_eq l) in
|
||||
Plugin_intf.Unsat (c, ())
|
||||
|
||||
|
|
@ -9,3 +9,6 @@ include Formula_intf.S
|
|||
val make : int -> t
|
||||
(** Make a proposition from an integer. *)
|
||||
|
||||
val fresh : unit -> t
|
||||
(** Make a fresh atom *)
|
||||
|
||||
|
|
|
|||
|
|
@ -318,7 +318,8 @@ let rec parse_expr (env : env) t =
|
|||
| _ -> _bad_arity "xor" 2 t
|
||||
end
|
||||
|
||||
| { Ast.term = Ast.App ({Ast.term = Ast.Builtin Ast.Imply}, l) } as t ->
|
||||
| ({ Ast.term = Ast.App ({Ast.term = Ast.Builtin Ast.Imply}, l) } as t)
|
||||
| ({ Ast.term = Ast.App ({Ast.term = Ast.Symbol { Id.name = "=>" }}, l) } as t) ->
|
||||
begin match l with
|
||||
| [p; q] ->
|
||||
let f = parse_formula env p in
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@
|
|||
|
||||
module type S = Tseitin_intf.S
|
||||
|
||||
module Make (F : Formula_intf.S) = struct
|
||||
module Make (F : Tseitin_intf.Arg) = struct
|
||||
|
||||
exception Empty_Or
|
||||
type combinator = And | Or | Imp | Not
|
||||
|
|
|
|||
|
|
@ -6,4 +6,5 @@ Copyright 2014 Simon Cruanes
|
|||
|
||||
module type S = Tseitin_intf.S
|
||||
|
||||
module Make : functor (F : Formula_intf.S) -> S with type atom = F.t
|
||||
module Make : functor
|
||||
(F : Tseitin_intf.Arg) -> S with type atom = F.t
|
||||
|
|
|
|||
|
|
@ -10,6 +10,22 @@
|
|||
(* *)
|
||||
(**************************************************************************)
|
||||
|
||||
module type Arg = sig
|
||||
|
||||
type t
|
||||
(** Type of atomic formulas *)
|
||||
|
||||
val neg : t -> t
|
||||
(** Negation of atomic formulas *)
|
||||
|
||||
val fresh : unit -> t
|
||||
(** Generate fresh formulas *)
|
||||
|
||||
val print : Format.formatter -> t -> unit
|
||||
(** Print the given formula *)
|
||||
|
||||
end
|
||||
|
||||
module type S = sig
|
||||
|
||||
(** The type of ground formulas *)
|
||||
|
|
|
|||
99
src/util/backtrack.ml
Normal file
99
src/util/backtrack.ml
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
|
||||
module Stack = struct
|
||||
|
||||
type op =
|
||||
(* Stack structure *)
|
||||
| Nil : op
|
||||
| Level : op * int -> op
|
||||
(* Undo operations *)
|
||||
| Set : 'a ref * 'a * op -> op
|
||||
| Call1 : ('a -> unit) * 'a * op -> op
|
||||
| Call2 : ('a -> 'b -> unit) * 'a * 'b * op -> op
|
||||
| Call3 : ('a -> 'b -> 'c -> unit) * 'a * 'b * 'c * op -> op
|
||||
| CallUnit : (unit -> unit) * op -> op
|
||||
|
||||
type t = {
|
||||
mutable stack : op;
|
||||
mutable last : int;
|
||||
}
|
||||
|
||||
type level = int
|
||||
|
||||
let dummy_level = -1
|
||||
|
||||
let create () = {
|
||||
stack = Nil;
|
||||
last = dummy_level;
|
||||
}
|
||||
|
||||
let register_set t ref value = t.stack <- Set(ref, value, t.stack)
|
||||
let register_undo t f = t.stack <- CallUnit (f, t.stack)
|
||||
let register1 t f x = t.stack <- Call1 (f, x, t.stack)
|
||||
let register2 t f x y = t.stack <- Call2 (f, x, y, t.stack)
|
||||
let register3 t f x y z = t.stack <- Call3 (f, x, y, z, t.stack)
|
||||
|
||||
let curr = ref 0
|
||||
|
||||
let push t =
|
||||
let level = !curr in
|
||||
t.stack <- Level (t.stack, level);
|
||||
t.last <- level;
|
||||
incr curr
|
||||
|
||||
let rec level t =
|
||||
match t.stack with
|
||||
| Level (_, lvl) -> lvl
|
||||
| _ -> push t; level t
|
||||
|
||||
let backtrack t lvl =
|
||||
let rec pop = function
|
||||
| Nil -> assert false
|
||||
| Level (op, level) as current ->
|
||||
if level = lvl then begin
|
||||
t.stack <- current;
|
||||
t.last <- level
|
||||
end else
|
||||
pop op
|
||||
| Set (ref, x, op) -> ref := x; pop op
|
||||
| CallUnit (f, op) -> f (); pop op
|
||||
| Call1 (f, x, op) -> f x; pop op
|
||||
| Call2 (f, x, y, op) -> f x y; pop op
|
||||
| Call3 (f, x, y, z, op) -> f x y z; pop op
|
||||
in
|
||||
pop t.stack
|
||||
|
||||
let pop t = backtrack t (t.last)
|
||||
|
||||
end
|
||||
|
||||
module Hashtbl(K : Hashtbl.HashedType) = struct
|
||||
|
||||
module H = Hashtbl.Make(K)
|
||||
|
||||
type key = K.t
|
||||
type 'a t = {
|
||||
tbl : 'a H.t;
|
||||
stack : Stack.t;
|
||||
}
|
||||
|
||||
let create ?(size=256) stack = {tbl = H.create size; stack; }
|
||||
|
||||
let mem {tbl; _} x = H.mem tbl x
|
||||
let find {tbl; _} k = H.find tbl k
|
||||
|
||||
let add t k v =
|
||||
Stack.register2 t.stack H.remove t.tbl k;
|
||||
H.add t.tbl k v
|
||||
|
||||
let remove t k =
|
||||
try
|
||||
let v = find t k in
|
||||
Stack.register3 t.stack H.add t.tbl k v;
|
||||
H.remove t.tbl k
|
||||
with Not_found -> ()
|
||||
|
||||
let fold t f acc = H.fold f t.tbl acc
|
||||
|
||||
let iter f t = H.iter f t.tbl
|
||||
end
|
||||
|
||||
77
src/util/backtrack.mli
Normal file
77
src/util/backtrack.mli
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
|
||||
(** Provides helpers for backtracking.
|
||||
This module defines backtracking stacks, i.e stacks of undo actions
|
||||
to perform when backtracking to a certain point. This allows for
|
||||
side-effect backtracking, and so to have backtracking automatically
|
||||
handled by extensions without the need for explicit synchronisation
|
||||
between the dispatcher and the extensions.
|
||||
*)
|
||||
|
||||
module Stack : sig
|
||||
(** A backtracking stack is a stack of undo actions to perform
|
||||
in order to revert back to a (mutable) state. *)
|
||||
|
||||
type t
|
||||
(** The type for a stack. *)
|
||||
|
||||
type level
|
||||
(** The type of backtracking point. *)
|
||||
|
||||
val create : unit -> t
|
||||
(** Creates an empty stack. *)
|
||||
|
||||
val dummy_level : level
|
||||
(** A dummy level. *)
|
||||
|
||||
val push : t -> unit
|
||||
(** Creates a backtracking point at the top of the stack. *)
|
||||
|
||||
val pop : t -> unit
|
||||
(** Pop all actions in the undo stack until the first backtracking point. *)
|
||||
|
||||
val level : t -> level
|
||||
(** Insert a named backtracking point at the top of the stack. *)
|
||||
|
||||
val backtrack : t -> level -> unit
|
||||
(** Backtrack to the given named backtracking point. *)
|
||||
|
||||
val register_undo : t -> (unit -> unit) -> unit
|
||||
(** Adds a callback at the top of the stack. *)
|
||||
|
||||
val register1 : t -> ('a -> unit) -> 'a -> unit
|
||||
val register2 : t -> ('a -> 'b -> unit) -> 'a -> 'b -> unit
|
||||
val register3 : t -> ('a -> 'b -> 'c -> unit) -> 'a -> 'b -> 'c -> unit
|
||||
(** Register functions to be called on the given arguments at the top of the stack.
|
||||
Allows to save some space by not creating too much closure as would be the case if
|
||||
only [unit -> unit] callbacks were stored. *)
|
||||
|
||||
val register_set : t -> 'a ref -> 'a -> unit
|
||||
(** Registers a ref to be set to the given value upon backtracking. *)
|
||||
|
||||
end
|
||||
|
||||
module Hashtbl :
|
||||
functor (K : Hashtbl.HashedType) ->
|
||||
sig
|
||||
(** Provides wrappers around hastables in order to have
|
||||
very simple integration with backtraking stacks.
|
||||
All actions performed on this table register the corresponding
|
||||
undo operations so that backtracking is automatic. *)
|
||||
|
||||
type key = K.t
|
||||
(** The type of keys of the Hashtbl. *)
|
||||
|
||||
type 'a t
|
||||
(** The type of hastable from keys to values of type ['a]. *)
|
||||
|
||||
val create : ?size:int -> Stack.t -> 'a t
|
||||
(** Creates an empty hashtable, that registers undo operations on the given stack. *)
|
||||
|
||||
val add : 'a t -> key -> 'a -> unit
|
||||
val mem : 'a t -> key -> bool
|
||||
val find : 'a t -> key -> 'a
|
||||
val remove : 'a t -> key -> unit
|
||||
val iter : (key -> 'a -> unit) -> 'a t -> unit
|
||||
val fold : 'a t -> (key -> 'a -> 'b -> 'b) -> 'b -> 'b
|
||||
(** Usual operations on the hashtabl. For more information see the Hashtbl module of the stdlib. *)
|
||||
end
|
||||
Loading…
Add table
Reference in a new issue