sidekick/src/mcsat/eclosure.ml
2016-09-22 18:31:22 +02:00

231 lines
5.6 KiB
OCaml

(** Requirements on the element type handled by the closure. *)
module type Key = sig
(** Type of the elements. *)
type t
(** Hash function on elements.
    NOTE(review): presumably consumed by [Backtrack.Hashtbl] — confirm. *)
val hash : t -> int
(** Equality on elements; [Make] uses it to compare tag values. *)
val equal : t -> t -> bool
(** Ordering on elements; [Make] uses it for [Map.Make] and to test
    whether two representatives are the same element. *)
val compare : t -> t -> int
(** Pretty-printer for elements; [Make] uses it in debug log messages. *)
val print : Format.formatter -> t -> unit
end
(** Interface of an equivalence closure over elements of type [var].
    The per-item descriptions below reflect what the [Make] functor of this
    file does. *)
module type S = sig
(** State of the closure. *)
type t
(** Elements being related. *)
type var
(** [Unsat (a, b, l)] reports a conflict involving [a] and [b]; [l] is the
    list of elements produced as its explanation. *)
exception Unsat of var * var * var list
(** [create s] builds an empty closure from the backtrack stack [s]. *)
val create : Backtrack.Stack.t -> t
(** [find t x] returns the current representative of [x]. *)
val find : t -> var -> var
(** [add_eq t x y] records that [x] and [y] are equal.
    @raise Unsat if the two classes were declared distinct, or carry
    different tag values. *)
val add_eq : t -> var -> var -> unit
(** [add_neq t x y] records that [x] and [y] must stay distinct.
    @raise Unsat if they already have the same representative, or if both
    classes carry the same tag value. *)
val add_neq : t -> var -> var -> unit
(** [add_tag t x v] attaches the tag value [v] to the class of [x].
    @raise Unsat if that class already carries a different tag value. *)
val add_tag : t -> var -> var -> unit
(** [find_tag t x] returns the representative of [x] together with the
    tag of its class, if any, as a pair [(tagged element, tag value)]. *)
val find_tag : t -> var -> var * (var * var) option
end
module Make(T : Key) = struct
(* Maps keyed by elements; used for the [forbidden] field of [repr_info]. *)
module M = Map.Make(T)
(* Hash tables keyed by elements, provided by the project's [Backtrack]
   module. NOTE(review): presumably their updates are undone when the
   backtrack stack is rewound — confirm against [Backtrack]. *)
module H = Backtrack.Hashtbl(T)
(* Elements handled by the closure. *)
type var = T.t
(* Internal conflict signal raised by the union-find layer ([tag],
   [forbid_aux], [link], [forbid]); the public wrappers turn it into
   [Unsat]. *)
exception Equal of var * var
(* Internal signal raised by [forbid] when both classes carry the same tag
   value; the payload is the pair of tagged elements. *)
exception Same_tag of var * var
(* Public conflict exception: two elements plus the element list computed
   by [expl] for them. *)
exception Unsat of var * var * var list
(* Data stored for the representative of a class. *)
type repr_info = {
(* Rank compared by [union] to decide which representative survives. *)
rank : int;
(* Optional tag of the class: the element that was tagged and its tag value. *)
tag : (T.t * T.t) option;
(* Representatives this class must stay distinct from, each mapped to the
   pair of elements recorded when the disequality was added. *)
forbidden : (var * var) M.t;
}
(* Entry of the union-find table. *)
type node =
(* The element is not a representative; follow the link towards it. *)
| Follow of var
(* The element is a representative and carries the class data. *)
| Repr of repr_info
(* State of the closure: three tables keyed by elements. *)
type t = {
(* Size counters read and written by [add_eq_aux] for roots of the
   explanation forest; a missing entry counts as 0. *)
size : int H.t;
(* Parent links of the explanation forest; an element without an entry is
   its own parent. *)
expl : var H.t;
(* Union-find nodes; an element without an entry is a fresh representative. *)
repr : node H.t;
}
(* [create s] returns an empty closure whose three tables are all created
   from the backtrack stack [s]. *)
let create s =
  let size = H.create s in
  let expl = H.create s in
  let repr = H.create s in
  { size; expl; repr }
(* Union-find algorithm with path compression *)
(* Node assumed for any element that has no entry in the union-find table:
   a representative of rank 0, untagged, with no forbidden class. *)
let self_repr =
  let fresh = { rank = 0; tag = None; forbidden = M.empty } in
  Repr fresh
(* [find_hash m i default] is the binding of [i] in table [m], or [default]
   when [i] is unbound. *)
let find_hash m i default =
  match H.find m i with
  | bound -> bound
  | exception Not_found -> default
(* [find_aux m i] walks the [Follow] links from [i] and returns the class
   data together with the representative. Every element visited on the way
   is re-pointed directly at the representative (path compression). *)
let rec find_aux m i =
  match find_hash m i self_repr with
  | Follow parent ->
    let (_, top) as found = find_aux m parent in
    H.add m i (Follow top);
    found
  | Repr info -> info, i
(* [get_repr h x] is [find_aux] on the union-find table of [h], with the
   result swapped: representative first, class data second. *)
let get_repr h x =
  match find_aux h.repr x with
  | info, top -> top, info
(* [tag h x v] tags the class of [x] with value [v].
   - No tag yet: the class gets tag [(x, v)].
   - Already tagged with a value equal to [v] (per [T.equal]): the existing
     tag is kept.
   - Already tagged with a different value: raises [Equal (x, y)] where [y]
     is the representative of [x].
   The representative's entry is rewritten in every non-raising case.
   NOTE(review): the conflict is reported against the representative, not
   against the element that carries the existing tag — confirm intended. *)
let tag h x v =
  let r, y = find_aux h.repr x in
  let updated =
    match r.tag with
    | None -> Some (x, v)
    | Some (_, v') as current ->
      if T.equal v v' then current else raise (Equal (x, y))
  in
  H.add h.repr y (Repr { r with tag = updated })
(* [find h x] is the representative of [x]. *)
let find h x =
  let top, _ = get_repr h x in
  top
(* [find_tag h x] is the representative of [x] paired with the tag of its
   class, if any. *)
let find_tag h x =
  match find_aux h.repr x with
  | { tag = class_tag; _ }, top -> top, class_tag
(* [forbid_aux m x] raises [Equal (a, b)] when [x] is bound to [(a, b)] in
   the forbidden map [m]; does nothing when [x] is unbound. *)
let forbid_aux m x =
  match M.find x m with
  | a, b -> raise (Equal (a, b))
  | exception Not_found -> ()
(* [link h x mx y my] makes representative [y] (class data [my]) follow
   representative [x] (class data [mx]).
   The merged class data is:
   - rank: [mx.rank + 1] when both ranks are equal, [mx.rank] otherwise;
   - tag: the only tag present, or [mx]'s tag when both carry equal values;
     if both carry different values, a debug message is logged and
     [Equal] is raised on the two tagged elements before any table write;
   - forbidden: the union of both maps, [mx]'s binding winning on shared keys.
   Every class listed in [my.forbidden] also has its own forbidden map
   rewritten so that it refers to [x] instead of [y]. *)
let link h x mx y my =
  let rank = if mx.rank = my.rank then mx.rank + 1 else mx.rank in
  let tag =
    match mx.tag, my.tag with
    | None, None -> None
    | (Some _ as only), None | None, (Some _ as only) -> only
    | Some (z, t1), Some (w, t2) ->
      if T.equal t1 t2 then Some (z, t1)
      else begin
        Log.debugf 3 "Tag shenanigan : %a (%a) <> %a (%a)" (fun k ->
            k T.print t1 T.print z T.print t2 T.print w);
        raise (Equal (z, w))
      end
  in
  (* Left-biased choice; [M.merge] never passes two [None]. *)
  let prefer_left _ left right =
    match left, right with
    | (Some _ as kept), _ -> kept
    | None, (Some _ as kept) -> kept
    | None, None -> assert false
  in
  let forbidden = M.merge prefer_left mx.forbidden my.forbidden in
  (* Keys of a forbidden map are expected to have a [Repr] entry. *)
  let redirect z eq =
    match H.find h.repr z with
    | Repr r ->
      let moved = M.add x eq (M.remove y r.forbidden) in
      H.add h.repr z (Repr { r with forbidden = moved })
    | Follow _ -> assert false
  in
  M.iter redirect my.forbidden;
  H.add h.repr y (Follow x);
  H.add h.repr x (Repr { rank; tag; forbidden })
(* [union h x y] merges the classes of [x] and [y]; nothing happens when
   they already share a representative. Raises [Equal] (through
   [forbid_aux]) if either class forbids the other. The representative
   with the strictly larger rank survives; on equal ranks, [y]'s does. *)
let union h x y =
  let top_x, info_x = get_repr h x in
  let top_y, info_y = get_repr h y in
  match T.compare top_x top_y with
  | 0 -> ()
  | _ ->
    forbid_aux info_x.forbidden top_y;
    forbid_aux info_y.forbidden top_x;
    if info_x.rank > info_y.rank then link h top_x info_x top_y info_y
    else link h top_y info_y top_x info_x
(* [forbid h x y] records that the classes of [x] and [y] must stay
   distinct, storing the pair [(x, y)] in both representatives' forbidden
   maps.
   Raises [Equal (x, y)] when both already share a representative, and
   [Same_tag] on the two tagged elements when both classes carry the same
   tag value (per [T.compare]). *)
let forbid h x y =
  let rx, mx = get_repr h x in
  let ry, my = get_repr h y in
  if T.compare rx ry = 0 then raise (Equal (x, y));
  let clash =
    match mx.tag, my.tag with
    | Some (a, v), Some (b, v') when T.compare v v' = 0 -> Some (a, b)
    | _ -> None
  in
  match clash with
  | Some (a, b) -> raise (Same_tag(a, b))
  | None ->
    let witness = (x, y) in
    H.add h.repr ry (Repr { my with forbidden = M.add rx witness my.forbidden });
    H.add h.repr rx (Repr { mx with forbidden = M.add ry witness mx.forbidden })
(* Equivalence closure with explanation output *)
(* [find_parent v m] is the parent of [v] in table [m]; an element without
   an entry is its own parent. *)
let find_parent v m =
  match H.find m v with
  | parent -> parent
  | exception Not_found -> v
(* [root m acc curr] climbs the parent links of [m] from [curr] and returns
   the visited elements, topmost first, prepended to [acc]. *)
let rec root m acc curr =
  let visited = curr :: acc in
  let parent = find_parent curr m in
  match T.compare curr parent with
  | 0 -> visited
  | _ -> root m visited parent
(* [rev_root m curr] reverses every parent link on the path from [curr] up
   to the top of its tree, so that [curr] ends up without a parent, and
   returns the element that was at the top before the reversal. *)
let rec rev_root m curr =
  let next = find_parent curr m in
  match T.compare curr next with
  | 0 -> curr
  | _ ->
    (* Detach [curr] first, then flip the link once the rest is reversed. *)
    H.remove m curr;
    let former_top = rev_root m next in
    H.add m next curr;
    former_top
(* [expl t a b] lists the elements on the way from [a] to [b] in the
   explanation forest: both top-down paths are computed with [root], their
   common prefix is dropped except for its last element, and the remainder
   of [a]'s path is reversed in front of the remainder of [b]'s path.
   When the paths share no prefix the two are simply concatenated that way. *)
let expl t a b =
  let rec strip last path_a path_b =
    match path_a, path_b with
    | x :: rest_a, y :: rest_b when T.compare x y = 0 ->
      strip (Some x) rest_a rest_b
    | _ ->
      let path_a =
        match last with
        | Some z -> z :: path_a
        | None -> path_a
      in
      List.rev_append path_a path_b
  in
  strip None (root t.expl [] a) (root t.expl [] b)
(* [add_eq_aux t i j] adds the edge between [i] and [j] to the explanation
   forest, unless both already share a representative. Both trees are first
   re-rooted at [i] and [j]; the one whose previous top has the strictly
   smaller size counter is hung under the other ([j] under [i] on ties), and
   the new top receives the sum of both counters plus one. *)
let add_eq_aux t i j =
  if T.compare (find t i) (find t j) <> 0 then begin
    let counter top = find_hash t.size top 0 in
    let old_root_i = rev_root t.expl i in
    let old_root_j = rev_root t.expl j in
    let nb_i = counter old_root_i in
    let nb_j = counter old_root_j in
    let child, parent = if nb_i < nb_j then i, j else j, i in
    H.add t.expl child parent;
    H.add t.size parent (nb_i + nb_j + 1)
  end
(* Public wrappers: they turn the internal [Equal] / [Same_tag] exceptions
 * into [Unsat], attaching the explanation computed by [expl] *)
(* [add_tag t x v] runs [tag t x v]; an [Equal (a, b)] conflict is re-raised
   as [Unsat (a, b, expl t a b)]. *)
let add_tag t x v =
  try tag t x v
  with Equal (a, b) ->
    let why = expl t a b in
    raise (Unsat (a, b, why))
(* [add_eq t i j] first records the edge in the explanation forest, then
   merges the two classes; an [Equal (a, b)] conflict raised by the merge is
   re-raised as [Unsat (a, b, expl t a b)]. *)
let add_eq t i j =
  add_eq_aux t i j;
  try union t i j
  with Equal (a, b) ->
    let why = expl t a b in
    raise (Unsat (a, b, why))
(* [add_neq t i j] runs [forbid t i j].
   - [Equal (a, b)] is re-raised as [Unsat (a, b, expl t a b)].
   - On [Same_tag], the edge between [i] and [j] is added to the explanation
     forest and [Unsat (i, j, expl t i j)] is raised; the elements carried
     by [Same_tag] are not used. *)
let add_neq t i j =
  try forbid t i j with
  | Equal (a, b) ->
    let why = expl t a b in
    raise (Unsat (a, b, why))
  | Same_tag (_, _) ->
    add_eq_aux t i j;
    let res = expl t i j in
    raise (Unsat (i, j, res))
end