mirror of
https://github.com/c-cube/ocaml-containers.git
synced 2025-12-06 11:15:31 -05:00
494 lines
16 KiB
OCaml
494 lines
16 KiB
OCaml
(*
|
|
Copyright (c) 2013, Simon Cruanes
|
|
All rights reserved.
|
|
|
|
Redistribution and use in source and binary forms, with or without
|
|
modification, are permitted provided that the following conditions are met:
|
|
|
|
Redistributions of source code must retain the above copyright notice, this
|
|
list of conditions and the following disclaimer. Redistributions in binary
|
|
form must reproduce the above copyright notice, this list of conditions and the
|
|
following disclaimer in the documentation and/or other materials provided with
|
|
the distribution.
|
|
|
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
|
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
|
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*)
|
|
|
|
(** {1 Functional Congruence Closure} *)
|
|
|
|
(** This implementation follows more or less the paper
|
|
"fast congruence closure and extensions" by Nieuwenhuis & Oliveras.
|
|
It uses semi-persistent data structures but still thrives for efficiency. *)
|
|
|
|
(** {2 Curryfied terms} *)
|
|
|
|
module type CurryfiedTerm = sig
|
|
type symbol
|
|
type t = private {
|
|
shape : shape; (** Which kind of term is it? *)
|
|
tag : int; (** Unique ID *)
|
|
} (** A curryfied term *)
|
|
and shape = private
|
|
| Const of symbol (** Constant *)
|
|
| Apply of t * t (** Curryfied application *)
|
|
|
|
val mk_const : symbol -> t
|
|
val mk_app : t -> t -> t
|
|
val get_id : t -> int
|
|
val eq : t -> t -> bool
|
|
val pp_skel : out_channel -> t -> unit (* print tags recursively *)
|
|
end
|
|
|
|
module Curryfy(X : Hashtbl.HashedType) = struct
|
|
type symbol = X.t
|
|
type t = {
|
|
shape : shape; (** Which kind of term is it? *)
|
|
tag : int; (** Unique ID *)
|
|
}
|
|
and shape =
|
|
| Const of symbol (** Constant *)
|
|
| Apply of t * t (** Curryfied application *)
|
|
|
|
type term = t
|
|
|
|
module WE = Weak.Make(struct
|
|
type t = term
|
|
let equal a b = match a.shape, b.shape with
|
|
| Const ia, Const ib -> X.equal ia ib
|
|
| Apply (a1,a2), Apply (b1,b2) -> a1 == b1 && a2 == b2
|
|
| _ -> false
|
|
let hash a = match a.shape with
|
|
| Const i -> X.hash i
|
|
| Apply (a, b) -> a.tag * 65599 + b.tag
|
|
end)
|
|
|
|
let __table = WE.create 10001
|
|
let count = ref 0
|
|
|
|
let hashcons t =
|
|
let t' = WE.merge __table t in
|
|
(if t == t' then incr count);
|
|
t'
|
|
|
|
let mk_const i =
|
|
let t = {shape=Const i; tag= !count; } in
|
|
hashcons t
|
|
|
|
let mk_app a b =
|
|
let t = {shape=Apply (a, b); tag= !count; } in
|
|
hashcons t
|
|
|
|
let get_id t = t.tag
|
|
|
|
let eq t1 t2 = t1 == t2
|
|
|
|
let rec pp_skel oc t = match t.shape with
|
|
| Const _ -> Printf.fprintf oc "%d" t.tag
|
|
| Apply (t1, t2) ->
|
|
Printf.fprintf oc "(%a %a):%d" pp_skel t1 pp_skel t2 t.tag
|
|
end
|
|
|
|
(** {2 Congruence Closure} *)
|
|
|
|
module type S = sig
|
|
module CT : CurryfiedTerm
|
|
|
|
type t
|
|
(** Congruence Closure instance *)
|
|
|
|
exception Inconsistent of t * CT.t * CT.t * CT.t * CT.t
|
|
(** Exception raised when equality and inequality constraints are
|
|
inconsistent. [Inconsistent (a, b, a', b')] means that [a=b, a=a', b=b'] in
|
|
the congruence closure, but [a' != b'] was asserted before. *)
|
|
|
|
val create : int -> t
|
|
(** Create an empty CC of given size *)
|
|
|
|
val eq : t -> CT.t -> CT.t -> bool
|
|
(** Check whether the two terms are equal *)
|
|
|
|
val merge : t -> CT.t -> CT.t -> t
|
|
(** Assert that the two terms are equal (may raise Inconsistent) *)
|
|
|
|
val distinct : t -> CT.t -> CT.t -> t
|
|
(** Assert that the two given terms are distinct (may raise Inconsistent) *)
|
|
|
|
type action =
|
|
| Merge of CT.t * CT.t
|
|
| Distinct of CT.t * CT.t
|
|
(** Action that can be performed on the CC *)
|
|
|
|
val do_action : t -> action -> t
|
|
(** Perform the given action (may raise Inconsistent) *)
|
|
|
|
val can_eq : t -> CT.t -> CT.t -> bool
|
|
(** Check whether the two terms can be equal *)
|
|
|
|
val iter_equiv_class : t -> CT.t -> (CT.t -> unit) -> unit
|
|
(** Iterate on terms that are congruent to the given term *)
|
|
|
|
type explanation =
|
|
| ByCongruence of CT.t * CT.t (* direct congruence of terms *)
|
|
| ByMerge of CT.t * CT.t (* user merge of terms *)
|
|
|
|
val explain : t -> CT.t -> CT.t -> explanation list
|
|
(** Explain why those two terms are equal (assuming they are,
|
|
otherwise raises Invalid_argument) by returning a list
|
|
of merges. *)
|
|
end
|
|
|
|
module Make(T : CurryfiedTerm) = struct
|
|
module CT = T
|
|
module BV = Puf.PBitVector
|
|
module Puf = Puf.Make(CT)
|
|
|
|
module HashedCT = struct
|
|
type t = CT.t
|
|
let equal t1 t2 = t1.CT.tag = t2.CT.tag
|
|
let hash t = t.CT.tag
|
|
end
|
|
|
|
(* Persistent Hashtable on curryfied terms *)
|
|
module THashtbl = CCPersistentHashtbl.Make(HashedCT)
|
|
|
|
(* Persistent Hashtable on pairs of curryfied terms *)
|
|
module T2Hashtbl = CCPersistentHashtbl.Make(struct
|
|
type t = CT.t * CT.t
|
|
let equal (t1,t1') (t2,t2') = t1.CT.tag = t2.CT.tag && t1'.CT.tag = t2'.CT.tag
|
|
let hash (t,t') = t.CT.tag * 65599 + t'.CT.tag
|
|
end)
|
|
|
|
type t = {
|
|
uf : pending_eqn Puf.t; (* representatives for terms *)
|
|
defined : BV.t; (* is the term defined? *)
|
|
use : eqn list THashtbl.t; (* for all repr a, a -> all a@b=c and b@a=c *)
|
|
lookup : eqn T2Hashtbl.t; (* for all reprs a,b, some a@b=c (if any) *)
|
|
inconsistent : (CT.t * CT.t) option;
|
|
} (** Congruence Closure data structure *)
|
|
and eqn =
|
|
| EqnSimple of CT.t * CT.t (* t1 = t2 *)
|
|
| EqnApply of CT.t * CT.t * CT.t (* (t1 @ t2) = t3 *)
|
|
(** Equation between two terms *)
|
|
and pending_eqn =
|
|
| PendingSimple of eqn
|
|
| PendingDouble of eqn * eqn
|
|
|
|
exception Inconsistent of t * CT.t * CT.t * CT.t * CT.t
|
|
(** Exception raised when equality and inequality constraints are
|
|
inconsistent. [Inconsistent (a, b, a', b')] means that [a=b, a=a', b=b'] in
|
|
the congruence closure, but [a' != b'] was asserted before. *)
|
|
|
|
(** Create an empty CC of given size *)
|
|
let create size =
|
|
{ uf = Puf.create size;
|
|
defined = BV.make 3;
|
|
use = THashtbl.create size;
|
|
lookup = T2Hashtbl.create size;
|
|
inconsistent = None;
|
|
}
|
|
|
|
let mem cc t =
|
|
BV.get cc.defined t.CT.tag
|
|
|
|
let is_const t = match t.CT.shape with
|
|
| CT.Const _ -> true
|
|
| CT.Apply _ -> false
|
|
|
|
(** Merge equations in the congruence closure structure. [q] is a list
|
|
of [eqn], processed in FIFO order. May raise Inconsistent. *)
|
|
let rec merge cc eqn = match eqn with
|
|
| EqnSimple (a, b) ->
|
|
(* a=b, just propagate *)
|
|
propagate cc [PendingSimple eqn]
|
|
| EqnApply (a1, a2, a) ->
|
|
(* (a1 @ a2) = a *)
|
|
let a1' = Puf.find cc.uf a1 in
|
|
let a2' = Puf.find cc.uf a2 in
|
|
begin try
|
|
(* eqn' is (b1 @ b2) = b for some b1=a1', b2=a2' *)
|
|
let eqn' = T2Hashtbl.find cc.lookup (a1', a2') in
|
|
(* merge a and b because of eqn and eqn' *)
|
|
propagate cc [PendingDouble (eqn, eqn')]
|
|
with Not_found ->
|
|
(* remember that a1' @ a2' = a *)
|
|
let lookup = T2Hashtbl.replace cc.lookup (a1', a2') eqn in
|
|
let use_a1' = try THashtbl.find cc.use a1' with Not_found -> [] in
|
|
let use_a2' = try THashtbl.find cc.use a2' with Not_found -> [] in
|
|
let use = THashtbl.replace cc.use a1' (eqn::use_a1') in
|
|
let use = THashtbl.replace use a2' (eqn::use_a2') in
|
|
{ cc with use; lookup; }
|
|
end
|
|
(* propagate: merge pending equations *)
|
|
and propagate cc eqns =
|
|
let pending = ref eqns in
|
|
let uf = ref cc.uf in
|
|
let use = ref cc.use in
|
|
let lookup = ref cc.lookup in
|
|
(* process each pending equation *)
|
|
while !pending <> [] do
|
|
let eqn = List.hd !pending in
|
|
pending := List.tl !pending;
|
|
(* extract the two merged terms *)
|
|
let a, b = match eqn with
|
|
| PendingSimple (EqnSimple (a, b)) -> a, b
|
|
| PendingDouble (EqnApply (a1,a2,a), EqnApply (b1,b2,b)) -> a, b
|
|
| _ -> assert false
|
|
in
|
|
let a' = Puf.find !uf a in
|
|
let b' = Puf.find !uf b in
|
|
if not (CT.eq a' b') then begin
|
|
let use_a' = try THashtbl.find !use a' with Not_found -> [] in
|
|
let use_b' = try THashtbl.find !use b' with Not_found -> [] in
|
|
(* merge a and b's equivalence classes *)
|
|
(* Format.printf "merge %d %d@." a.CT.tag b.CT.tag; *)
|
|
uf := Puf.union !uf a b eqn;
|
|
(* check which of [a'] and [b'] is the new representative. [repr] is
|
|
the new representative, and [other] is the former representative *)
|
|
let repr = Puf.find !uf a' in
|
|
let use_repr = ref (if CT.eq repr a' then use_a' else use_b') in
|
|
let use_other = if CT.eq repr a' then use_b' else use_a' in
|
|
(* consider all c1@c2=c in use(a') *)
|
|
List.iter
|
|
(fun eqn -> match eqn with
|
|
| EqnSimple _ -> ()
|
|
| EqnApply (c1, c2, c) ->
|
|
let c1' = Puf.find !uf c1 in
|
|
let c2' = Puf.find !uf c2 in
|
|
begin try
|
|
let eqn' = T2Hashtbl.find !lookup (c1', c2') in
|
|
(* merge eqn with eqn', by congruence *)
|
|
pending := (PendingDouble (eqn,eqn')) :: !pending
|
|
with Not_found ->
|
|
lookup := T2Hashtbl.replace !lookup (c1', c2') eqn;
|
|
use_repr := eqn :: !use_repr;
|
|
end)
|
|
use_other;
|
|
(* update use list of [repr] *)
|
|
use := THashtbl.replace !use repr !use_repr;
|
|
(* check for inconsistencies *)
|
|
match Puf.inconsistent !uf with
|
|
| None -> () (* consistent *)
|
|
| Some (t1, t2, t1', t2') ->
|
|
(* inconsistent *)
|
|
let cc = { cc with use= !use; lookup= !lookup; uf= !uf; } in
|
|
raise (Inconsistent (cc, t1, t2, t1', t2'))
|
|
end
|
|
done;
|
|
let cc = { cc with use= !use; lookup= !lookup; uf= !uf; } in
|
|
cc
|
|
|
|
(** Add the given term to the CC *)
|
|
let rec add cc t =
|
|
match t.CT.shape with
|
|
| CT.Const _ ->
|
|
cc (* always trivially defined *)
|
|
| CT.Apply (t1, t2) ->
|
|
if BV.get cc.defined t.CT.tag
|
|
then cc (* already defined *)
|
|
else begin
|
|
(* note that [t] is defined, add it to the UF to avoid GC *)
|
|
let defined = BV.set_true cc.defined t.CT.tag in
|
|
let cc = {cc with defined; } in
|
|
(* recursive add. invariant: if a term is added, then its subterms
|
|
also are (hence the base case of constants or already added terms). *)
|
|
let cc = add cc t1 in
|
|
let cc = add cc t2 in
|
|
let cc = merge cc (EqnApply (t1, t2, t)) in
|
|
cc
|
|
end
|
|
|
|
(** Check whether the two terms are equal *)
|
|
let eq cc t1 t2 =
|
|
let cc = add (add cc t1) t2 in
|
|
let t1' = Puf.find cc.uf t1 in
|
|
let t2' = Puf.find cc.uf t2 in
|
|
CT.eq t1' t2'
|
|
|
|
(** Assert that the two terms are equal (may raise Inconsistent) *)
|
|
let merge cc t1 t2 =
|
|
let cc = add (add cc t1) t2 in
|
|
merge cc (EqnSimple (t1, t2))
|
|
|
|
(** Assert that the two given terms are distinct (may raise Inconsistent) *)
|
|
let distinct cc t1 t2 =
|
|
let cc = add (add cc t1) t2 in
|
|
let t1' = Puf.find cc.uf t1 in
|
|
let t2' = Puf.find cc.uf t2 in
|
|
if CT.eq t1' t2'
|
|
then raise (Inconsistent (cc, t1', t2', t1, t2)) (* they are equal, fail *)
|
|
else
|
|
(* remember that they should not become equal *)
|
|
let uf = Puf.distinct cc.uf t1 t2 in
|
|
{ cc with uf; }
|
|
|
|
type action =
|
|
| Merge of CT.t * CT.t
|
|
| Distinct of CT.t * CT.t
|
|
(** Action that can be performed on the CC *)
|
|
|
|
let do_action cc action = match action with
|
|
| Merge (t1, t2) -> merge cc t1 t2
|
|
| Distinct (t1, t2) -> distinct cc t1 t2
|
|
|
|
(** Check whether the two terms can be equal *)
|
|
let can_eq cc t1 t2 =
|
|
let cc = add (add cc t1) t2 in
|
|
not (Puf.must_be_distinct cc.uf t1 t2)
|
|
|
|
(** Iterate on terms that are congruent to the given term *)
|
|
let iter_equiv_class cc t f =
|
|
Puf.iter_equiv_class cc.uf t f
|
|
|
|
(** {3 Auxilliary Union-find for explanations} *)
|
|
|
|
module SparseUF = struct
|
|
module H = Hashtbl.Make(HashedCT)
|
|
|
|
type t = uf_ref H.t
|
|
and uf_ref = {
|
|
term : CT.t;
|
|
mutable parent : CT.t;
|
|
mutable highest_node : CT.t;
|
|
} (** Union-find reference *)
|
|
|
|
let create size = H.create size
|
|
|
|
let get_ref uf t =
|
|
try H.find uf t
|
|
with Not_found ->
|
|
let r_t = { term=t; parent=t; highest_node=t; } in
|
|
H.add uf t r_t;
|
|
r_t
|
|
|
|
let rec find_ref uf r_t =
|
|
if CT.eq r_t.parent r_t.term
|
|
then r_t (* fixpoint *)
|
|
else
|
|
let r_t' = get_ref uf r_t.parent in
|
|
find_ref uf r_t' (* recurse (no path compression) *)
|
|
|
|
let find uf t =
|
|
try
|
|
let r_t = H.find uf t in
|
|
(find_ref uf r_t).term
|
|
with Not_found ->
|
|
t
|
|
|
|
let eq uf t1 t2 =
|
|
CT.eq (find uf t1) (find uf t2)
|
|
|
|
let highest_node uf t =
|
|
try
|
|
let r_t = H.find uf t in
|
|
(find_ref uf r_t).highest_node
|
|
with Not_found ->
|
|
t
|
|
|
|
(* oriented union (t1 -> t2), assuming t2 is "higher" than t1 *)
|
|
let union uf t1 t2 =
|
|
let r_t1' = find_ref uf (get_ref uf t1) in
|
|
let r_t2' = find_ref uf (get_ref uf t2) in
|
|
r_t1'.parent <- r_t2'.term
|
|
end
|
|
|
|
(** {3 Producing explanations} *)
|
|
|
|
type explanation =
|
|
| ByCongruence of CT.t * CT.t (* direct congruence of terms *)
|
|
| ByMerge of CT.t * CT.t (* user merge of terms *)
|
|
|
|
(** Explain why those two terms are equal (they must be) *)
|
|
let explain cc t1 t2 =
|
|
assert (eq cc t1 t2);
|
|
(* keeps track of which equalities are already explained *)
|
|
let explained = SparseUF.create 5 in
|
|
let explanations = ref [] in
|
|
(* equations waiting to be explained *)
|
|
let pending = Queue.create () in
|
|
Queue.push (t1,t2) pending;
|
|
(* explain why a=c, where c is the root of the proof forest a belongs to *)
|
|
let rec explain_along a c =
|
|
let a' = SparseUF.highest_node explained a in
|
|
if CT.eq a' c then ()
|
|
else match Puf.explain_step cc.uf a' with
|
|
| None -> assert (CT.eq a' c)
|
|
| Some (b, e) ->
|
|
(* a->b on the path from a to c *)
|
|
begin match e with
|
|
| PendingSimple (EqnSimple (a',b')) ->
|
|
explanations := ByMerge (a', b') :: !explanations
|
|
| PendingDouble (EqnApply (a1, a2, a'), EqnApply (b1, b2, b')) ->
|
|
explanations := ByCongruence (a', b') :: !explanations;
|
|
Queue.push (a1, b1) pending;
|
|
Queue.push (a2, b2) pending;
|
|
| _ -> assert false
|
|
end;
|
|
(* now a' = b is justified *)
|
|
SparseUF.union explained a' b;
|
|
(* recurse *)
|
|
let new_a = SparseUF.highest_node explained b in
|
|
explain_along new_a c
|
|
in
|
|
(* process pending equations *)
|
|
while not (Queue.is_empty pending) do
|
|
let a, b = Queue.pop pending in
|
|
if SparseUF.eq explained a b
|
|
then ()
|
|
else begin
|
|
let c = Puf.common_ancestor cc.uf a b in
|
|
explain_along a c;
|
|
explain_along b c;
|
|
end
|
|
done;
|
|
!explanations
|
|
end
|
|
|
|
module StrTerm = Curryfy(struct
|
|
type t = string
|
|
let equal s1 s2 = s1 = s2
|
|
let hash s = Hashtbl.hash s
|
|
end)
|
|
|
|
module StrCC = Make(StrTerm)
|
|
|
|
let lex str =
|
|
let lexer = Genlex.make_lexer ["("; ")"] in
|
|
lexer (Stream.of_string str)
|
|
|
|
let parse str =
|
|
let stream = lex str in
|
|
let rec parse_term () =
|
|
match Stream.peek stream with
|
|
| Some (Genlex.Kwd "(") ->
|
|
Stream.junk stream;
|
|
let t1 = parse_term () in
|
|
let t2 = parse_term () in
|
|
begin match Stream.peek stream with
|
|
| Some (Genlex.Kwd ")") ->
|
|
Stream.junk stream;
|
|
StrTerm.mk_app t1 t2 (* end apply *)
|
|
| _ -> raise (Stream.Error "expected )")
|
|
end
|
|
| Some (Genlex.Ident s) ->
|
|
Stream.junk stream;
|
|
StrTerm.mk_const s
|
|
| _ -> raise (Stream.Error "expected term")
|
|
in
|
|
parse_term ()
|
|
|
|
let rec pp fmt t =
|
|
match t.StrTerm.shape with
|
|
| StrTerm.Const s ->
|
|
Format.fprintf fmt "%s:%d" s t.StrTerm.tag
|
|
| StrTerm.Apply (t1, t2) ->
|
|
Format.fprintf fmt "(%a %a):%d" pp t1 pp t2 t.StrTerm.tag
|
|
|