ocaml-containers/misc/cC.ml
Simon Cruanes 4bc6c8a008 split into package core (no pack, 'CC' prefix, stable)
and misc where oneshot ideas go
2014-05-16 20:58:28 +02:00

494 lines
16 KiB
OCaml

(*
Copyright (c) 2013, Simon Cruanes
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer. Redistributions in binary
form must reproduce the above copyright notice, this list of conditions and the
following disclaimer in the documentation and/or other materials provided with
the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*)
(** {1 Functional Congruence Closure} *)
(** This implementation follows more or less the paper
"fast congruence closure and extensions" by Nieuwenhuis & Oliveras.
It uses semi-persistent data structures but still thrives for efficiency. *)
(** {2 Curryfied terms} *)
module type CurryfiedTerm = sig
type symbol
type t = private {
shape : shape; (** Which kind of term is it? *)
tag : int; (** Unique ID *)
} (** A curryfied term *)
and shape = private
| Const of symbol (** Constant *)
| Apply of t * t (** Curryfied application *)
val mk_const : symbol -> t
val mk_app : t -> t -> t
val get_id : t -> int
val eq : t -> t -> bool
val pp_skel : out_channel -> t -> unit (* print tags recursively *)
end
module Curryfy(X : Hashtbl.HashedType) = struct
type symbol = X.t
type t = {
shape : shape; (** Which kind of term is it? *)
tag : int; (** Unique ID *)
}
and shape =
| Const of symbol (** Constant *)
| Apply of t * t (** Curryfied application *)
type term = t
module WE = Weak.Make(struct
type t = term
let equal a b = match a.shape, b.shape with
| Const ia, Const ib -> X.equal ia ib
| Apply (a1,a2), Apply (b1,b2) -> a1 == b1 && a2 == b2
| _ -> false
let hash a = match a.shape with
| Const i -> X.hash i
| Apply (a, b) -> a.tag * 65599 + b.tag
end)
let __table = WE.create 10001
let count = ref 0
let hashcons t =
let t' = WE.merge __table t in
(if t == t' then incr count);
t'
let mk_const i =
let t = {shape=Const i; tag= !count; } in
hashcons t
let mk_app a b =
let t = {shape=Apply (a, b); tag= !count; } in
hashcons t
let get_id t = t.tag
let eq t1 t2 = t1 == t2
let rec pp_skel oc t = match t.shape with
| Const _ -> Printf.fprintf oc "%d" t.tag
| Apply (t1, t2) ->
Printf.fprintf oc "(%a %a):%d" pp_skel t1 pp_skel t2 t.tag
end
(** {2 Congruence Closure} *)
module type S = sig
module CT : CurryfiedTerm
type t
(** Congruence Closure instance *)
exception Inconsistent of t * CT.t * CT.t * CT.t * CT.t
(** Exception raised when equality and inequality constraints are
inconsistent. [Inconsistent (a, b, a', b')] means that [a=b, a=a', b=b'] in
the congruence closure, but [a' != b'] was asserted before. *)
val create : int -> t
(** Create an empty CC of given size *)
val eq : t -> CT.t -> CT.t -> bool
(** Check whether the two terms are equal *)
val merge : t -> CT.t -> CT.t -> t
(** Assert that the two terms are equal (may raise Inconsistent) *)
val distinct : t -> CT.t -> CT.t -> t
(** Assert that the two given terms are distinct (may raise Inconsistent) *)
type action =
| Merge of CT.t * CT.t
| Distinct of CT.t * CT.t
(** Action that can be performed on the CC *)
val do_action : t -> action -> t
(** Perform the given action (may raise Inconsistent) *)
val can_eq : t -> CT.t -> CT.t -> bool
(** Check whether the two terms can be equal *)
val iter_equiv_class : t -> CT.t -> (CT.t -> unit) -> unit
(** Iterate on terms that are congruent to the given term *)
type explanation =
| ByCongruence of CT.t * CT.t (* direct congruence of terms *)
| ByMerge of CT.t * CT.t (* user merge of terms *)
val explain : t -> CT.t -> CT.t -> explanation list
(** Explain why those two terms are equal (assuming they are,
otherwise raises Invalid_argument) by returning a list
of merges. *)
end
module Make(T : CurryfiedTerm) = struct
module CT = T
module BV = Puf.PBitVector
module Puf = Puf.Make(CT)
module HashedCT = struct
type t = CT.t
let equal t1 t2 = t1.CT.tag = t2.CT.tag
let hash t = t.CT.tag
end
(* Persistent Hashtable on curryfied terms *)
module THashtbl = CCPersistentHashtbl.Make(HashedCT)
(* Persistent Hashtable on pairs of curryfied terms *)
module T2Hashtbl = CCPersistentHashtbl.Make(struct
type t = CT.t * CT.t
let equal (t1,t1') (t2,t2') = t1.CT.tag = t2.CT.tag && t1'.CT.tag = t2'.CT.tag
let hash (t,t') = t.CT.tag * 65599 + t'.CT.tag
end)
type t = {
uf : pending_eqn Puf.t; (* representatives for terms *)
defined : BV.t; (* is the term defined? *)
use : eqn list THashtbl.t; (* for all repr a, a -> all a@b=c and b@a=c *)
lookup : eqn T2Hashtbl.t; (* for all reprs a,b, some a@b=c (if any) *)
inconsistent : (CT.t * CT.t) option;
} (** Congruence Closure data structure *)
and eqn =
| EqnSimple of CT.t * CT.t (* t1 = t2 *)
| EqnApply of CT.t * CT.t * CT.t (* (t1 @ t2) = t3 *)
(** Equation between two terms *)
and pending_eqn =
| PendingSimple of eqn
| PendingDouble of eqn * eqn
exception Inconsistent of t * CT.t * CT.t * CT.t * CT.t
(** Exception raised when equality and inequality constraints are
inconsistent. [Inconsistent (a, b, a', b')] means that [a=b, a=a', b=b'] in
the congruence closure, but [a' != b'] was asserted before. *)
(** Create an empty CC of given size *)
let create size =
{ uf = Puf.create size;
defined = BV.make 3;
use = THashtbl.create size;
lookup = T2Hashtbl.create size;
inconsistent = None;
}
let mem cc t =
BV.get cc.defined t.CT.tag
let is_const t = match t.CT.shape with
| CT.Const _ -> true
| CT.Apply _ -> false
(** Merge equations in the congruence closure structure. [q] is a list
of [eqn], processed in FIFO order. May raise Inconsistent. *)
let rec merge cc eqn = match eqn with
| EqnSimple (a, b) ->
(* a=b, just propagate *)
propagate cc [PendingSimple eqn]
| EqnApply (a1, a2, a) ->
(* (a1 @ a2) = a *)
let a1' = Puf.find cc.uf a1 in
let a2' = Puf.find cc.uf a2 in
begin try
(* eqn' is (b1 @ b2) = b for some b1=a1', b2=a2' *)
let eqn' = T2Hashtbl.find cc.lookup (a1', a2') in
(* merge a and b because of eqn and eqn' *)
propagate cc [PendingDouble (eqn, eqn')]
with Not_found ->
(* remember that a1' @ a2' = a *)
let lookup = T2Hashtbl.replace cc.lookup (a1', a2') eqn in
let use_a1' = try THashtbl.find cc.use a1' with Not_found -> [] in
let use_a2' = try THashtbl.find cc.use a2' with Not_found -> [] in
let use = THashtbl.replace cc.use a1' (eqn::use_a1') in
let use = THashtbl.replace use a2' (eqn::use_a2') in
{ cc with use; lookup; }
end
(* propagate: merge pending equations *)
and propagate cc eqns =
let pending = ref eqns in
let uf = ref cc.uf in
let use = ref cc.use in
let lookup = ref cc.lookup in
(* process each pending equation *)
while !pending <> [] do
let eqn = List.hd !pending in
pending := List.tl !pending;
(* extract the two merged terms *)
let a, b = match eqn with
| PendingSimple (EqnSimple (a, b)) -> a, b
| PendingDouble (EqnApply (a1,a2,a), EqnApply (b1,b2,b)) -> a, b
| _ -> assert false
in
let a' = Puf.find !uf a in
let b' = Puf.find !uf b in
if not (CT.eq a' b') then begin
let use_a' = try THashtbl.find !use a' with Not_found -> [] in
let use_b' = try THashtbl.find !use b' with Not_found -> [] in
(* merge a and b's equivalence classes *)
(* Format.printf "merge %d %d@." a.CT.tag b.CT.tag; *)
uf := Puf.union !uf a b eqn;
(* check which of [a'] and [b'] is the new representative. [repr] is
the new representative, and [other] is the former representative *)
let repr = Puf.find !uf a' in
let use_repr = ref (if CT.eq repr a' then use_a' else use_b') in
let use_other = if CT.eq repr a' then use_b' else use_a' in
(* consider all c1@c2=c in use(a') *)
List.iter
(fun eqn -> match eqn with
| EqnSimple _ -> ()
| EqnApply (c1, c2, c) ->
let c1' = Puf.find !uf c1 in
let c2' = Puf.find !uf c2 in
begin try
let eqn' = T2Hashtbl.find !lookup (c1', c2') in
(* merge eqn with eqn', by congruence *)
pending := (PendingDouble (eqn,eqn')) :: !pending
with Not_found ->
lookup := T2Hashtbl.replace !lookup (c1', c2') eqn;
use_repr := eqn :: !use_repr;
end)
use_other;
(* update use list of [repr] *)
use := THashtbl.replace !use repr !use_repr;
(* check for inconsistencies *)
match Puf.inconsistent !uf with
| None -> () (* consistent *)
| Some (t1, t2, t1', t2') ->
(* inconsistent *)
let cc = { cc with use= !use; lookup= !lookup; uf= !uf; } in
raise (Inconsistent (cc, t1, t2, t1', t2'))
end
done;
let cc = { cc with use= !use; lookup= !lookup; uf= !uf; } in
cc
(** Add the given term to the CC *)
let rec add cc t =
match t.CT.shape with
| CT.Const _ ->
cc (* always trivially defined *)
| CT.Apply (t1, t2) ->
if BV.get cc.defined t.CT.tag
then cc (* already defined *)
else begin
(* note that [t] is defined, add it to the UF to avoid GC *)
let defined = BV.set_true cc.defined t.CT.tag in
let cc = {cc with defined; } in
(* recursive add. invariant: if a term is added, then its subterms
also are (hence the base case of constants or already added terms). *)
let cc = add cc t1 in
let cc = add cc t2 in
let cc = merge cc (EqnApply (t1, t2, t)) in
cc
end
(** Check whether the two terms are equal *)
let eq cc t1 t2 =
let cc = add (add cc t1) t2 in
let t1' = Puf.find cc.uf t1 in
let t2' = Puf.find cc.uf t2 in
CT.eq t1' t2'
(** Assert that the two terms are equal (may raise Inconsistent) *)
let merge cc t1 t2 =
let cc = add (add cc t1) t2 in
merge cc (EqnSimple (t1, t2))
(** Assert that the two given terms are distinct (may raise Inconsistent) *)
let distinct cc t1 t2 =
let cc = add (add cc t1) t2 in
let t1' = Puf.find cc.uf t1 in
let t2' = Puf.find cc.uf t2 in
if CT.eq t1' t2'
then raise (Inconsistent (cc, t1', t2', t1, t2)) (* they are equal, fail *)
else
(* remember that they should not become equal *)
let uf = Puf.distinct cc.uf t1 t2 in
{ cc with uf; }
type action =
| Merge of CT.t * CT.t
| Distinct of CT.t * CT.t
(** Action that can be performed on the CC *)
let do_action cc action = match action with
| Merge (t1, t2) -> merge cc t1 t2
| Distinct (t1, t2) -> distinct cc t1 t2
(** Check whether the two terms can be equal *)
let can_eq cc t1 t2 =
let cc = add (add cc t1) t2 in
not (Puf.must_be_distinct cc.uf t1 t2)
(** Iterate on terms that are congruent to the given term *)
let iter_equiv_class cc t f =
Puf.iter_equiv_class cc.uf t f
(** {3 Auxilliary Union-find for explanations} *)
module SparseUF = struct
module H = Hashtbl.Make(HashedCT)
type t = uf_ref H.t
and uf_ref = {
term : CT.t;
mutable parent : CT.t;
mutable highest_node : CT.t;
} (** Union-find reference *)
let create size = H.create size
let get_ref uf t =
try H.find uf t
with Not_found ->
let r_t = { term=t; parent=t; highest_node=t; } in
H.add uf t r_t;
r_t
let rec find_ref uf r_t =
if CT.eq r_t.parent r_t.term
then r_t (* fixpoint *)
else
let r_t' = get_ref uf r_t.parent in
find_ref uf r_t' (* recurse (no path compression) *)
let find uf t =
try
let r_t = H.find uf t in
(find_ref uf r_t).term
with Not_found ->
t
let eq uf t1 t2 =
CT.eq (find uf t1) (find uf t2)
let highest_node uf t =
try
let r_t = H.find uf t in
(find_ref uf r_t).highest_node
with Not_found ->
t
(* oriented union (t1 -> t2), assuming t2 is "higher" than t1 *)
let union uf t1 t2 =
let r_t1' = find_ref uf (get_ref uf t1) in
let r_t2' = find_ref uf (get_ref uf t2) in
r_t1'.parent <- r_t2'.term
end
(** {3 Producing explanations} *)
type explanation =
| ByCongruence of CT.t * CT.t (* direct congruence of terms *)
| ByMerge of CT.t * CT.t (* user merge of terms *)
(** Explain why those two terms are equal (they must be) *)
let explain cc t1 t2 =
assert (eq cc t1 t2);
(* keeps track of which equalities are already explained *)
let explained = SparseUF.create 5 in
let explanations = ref [] in
(* equations waiting to be explained *)
let pending = Queue.create () in
Queue.push (t1,t2) pending;
(* explain why a=c, where c is the root of the proof forest a belongs to *)
let rec explain_along a c =
let a' = SparseUF.highest_node explained a in
if CT.eq a' c then ()
else match Puf.explain_step cc.uf a' with
| None -> assert (CT.eq a' c)
| Some (b, e) ->
(* a->b on the path from a to c *)
begin match e with
| PendingSimple (EqnSimple (a',b')) ->
explanations := ByMerge (a', b') :: !explanations
| PendingDouble (EqnApply (a1, a2, a'), EqnApply (b1, b2, b')) ->
explanations := ByCongruence (a', b') :: !explanations;
Queue.push (a1, b1) pending;
Queue.push (a2, b2) pending;
| _ -> assert false
end;
(* now a' = b is justified *)
SparseUF.union explained a' b;
(* recurse *)
let new_a = SparseUF.highest_node explained b in
explain_along new_a c
in
(* process pending equations *)
while not (Queue.is_empty pending) do
let a, b = Queue.pop pending in
if SparseUF.eq explained a b
then ()
else begin
let c = Puf.common_ancestor cc.uf a b in
explain_along a c;
explain_along b c;
end
done;
!explanations
end
module StrTerm = Curryfy(struct
type t = string
let equal s1 s2 = s1 = s2
let hash s = Hashtbl.hash s
end)
module StrCC = Make(StrTerm)
let lex str =
let lexer = Genlex.make_lexer ["("; ")"] in
lexer (Stream.of_string str)
let parse str =
let stream = lex str in
let rec parse_term () =
match Stream.peek stream with
| Some (Genlex.Kwd "(") ->
Stream.junk stream;
let t1 = parse_term () in
let t2 = parse_term () in
begin match Stream.peek stream with
| Some (Genlex.Kwd ")") ->
Stream.junk stream;
StrTerm.mk_app t1 t2 (* end apply *)
| _ -> raise (Stream.Error "expected )")
end
| Some (Genlex.Ident s) ->
Stream.junk stream;
StrTerm.mk_const s
| _ -> raise (Stream.Error "expected term")
in
parse_term ()
let rec pp fmt t =
match t.StrTerm.shape with
| StrTerm.Const s ->
Format.fprintf fmt "%s:%d" s t.StrTerm.tag
| StrTerm.Apply (t1, t2) ->
Format.fprintf fmt "(%a %a):%d" pp t1 pp t2 t.StrTerm.tag