mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-10 13:14:09 -05:00
321 lines
9.6 KiB
OCaml
321 lines
9.6 KiB
OCaml
(* Local alias for [Sidekick_core.CC_view]: the generic "view" of a term
   (constructors Bool, App_fun, App_ho, If, Eq, Not, Opaque) that [ARG.cc_view]
   returns and that [Make] opens to pattern-match on term shapes. *)
module CC_view = Sidekick_core.CC_view
|
|
|
|
(** Parameters of {!Make}: a term representation together with a way to
    inspect the top-level shape of a term. *)
module type ARG = sig
  module T : Sidekick_core.TERM
  (** Terms, function symbols and the term state used by the closure. *)

  val cc_view :
    T.Term.t ->
    (T.Fun.t, T.Term.t, T.Term.t Iter.t) CC_view.t
  (** [cc_view t] describes the head of [t] as one of the {!CC_view}
      constructors; immediate sub-terms are given as an iterator. [Make]
      uses it both to discover sub-terms and to compute signatures. *)
end
|
|
|
|
(** Interface of a small, self-contained congruence closure over terms.
    NOTE(review): [Make] below is not explicitly ascribed this signature;
    its structure is merely compatible with it. *)
module type S = sig
  type term
  and fun_
  and term_state
  (** Terms, function symbols, and the term state needed to build terms. *)

  type t
  (** Mutable congruence-closure state. *)

  val create : term_state -> t
  (** [create tst] builds an empty closure; [tst] is used to build the
      boolean constants [true] and [false]. *)

  val clear : t -> unit
  (** Forget every literal and term added so far. *)

  val add_lit : t -> term -> bool -> unit
  (** [add_lit cc p sign] asserts literal [p] with polarity [sign]
      ([true] for [p], [false] for [not p]). *)

  val check_sat : t -> bool
  (** Run the closure to a fixpoint; [false] means the asserted literals
      are inconsistent. *)

  val classes : t -> term Iter.t Iter.t
  (** Current equivalence classes, each given as an iterator of terms. *)
end
|
|
|
|
module Make(A: ARG) = struct
  open CC_view

  module Fun = A.T.Fun
  module T = A.T.Term
  type fun_ = A.T.Fun.t
  type term = T.t
  type term_state = T.state

  module T_tbl = CCHashtbl.Make(T)

  (** One node per term added so far. Each equivalence class is a circular
      list linked through [n_next]; every member points to the class
      representative through [n_root]. [n_size] and [n_parents] are only
      kept up to date on the root. *)
  type node = {
    n_t: term;
    mutable n_next: node; (* next in class *)
    mutable n_size: int; (* size of class *)
    mutable n_parents: node list; (* terms having a member of the class as direct sub-term *)
    mutable n_root: node; (* root of the class *)
  }

  (** Signature of a node: its view, with sub-terms replaced by the roots
      of their classes. Two nodes with equal signatures are congruent. *)
  type signature = (fun_, node, node list) CC_view.t

  module Node = struct
    type t = node
    let[@inline] equal (n1:t) n2 = T.equal n1.n_t n2.n_t
    let[@inline] hash (n:t) = T.hash n.n_t
    let[@inline] size (n:t) = n.n_size
    let[@inline] is_root n = n == n.n_root
    let[@inline] root n = n.n_root
    let[@inline] term n = n.n_t
    let pp out n = T.pp out n.n_t

    (** Record [p] as a parent of [self] (whose class should be rooted at [self]). *)
    let add_parent (self:t) ~p : unit =
      self.n_parents <- p :: self.n_parents

    (** Fresh singleton class for term [t]. *)
    let make (t:T.t) : t =
      let rec n = {
        n_t=t; n_size=1; n_next=n;
        n_parents=[]; n_root=n;
      } in
      n

    (** Turn [n] back into a singleton class with no parents, as if it had
        just been created by [make]. *)
    let reset_ (n:t) : unit =
      n.n_next <- n;
      n.n_size <- 1;
      n.n_parents <- [];
      n.n_root <- n

    (** [iter_cls n0 f] calls [f] on every member of [n0]'s class, starting
        with [n0] and following the circular [n_next] list. *)
    let iter_cls (n0:t) f : unit =
      let rec aux n =
        f n;
        let n' = n.n_next in
        if equal n' n0 then () else aux n'
      in
      aux n0
  end

  module Signature = struct
    type t = signature

    (** Structural equality on signatures. Equations are compared up to
        symmetry, so [a = b] and [b = a] are congruent. *)
    let equal (s1:t) s2 : bool =
      match s1, s2 with
      | Bool b1, Bool b2 -> b1=b2
      | App_fun (f1,l1), App_fun (f2,l2) ->
        Fun.equal f1 f2 && CCList.equal Node.equal l1 l2
      | App_ho (f1,l1), App_ho (f2,l2) ->
        Node.equal f1 f2 && CCList.equal Node.equal l1 l2
      | Not n1, Not n2 -> Node.equal n1 n2
      | If (a1,b1,c1), If (a2,b2,c2) ->
        Node.equal a1 a2 && Node.equal b1 b2 && Node.equal c1 c2
      | Eq (a1,b1), Eq (a2,b2) ->
        (Node.equal a1 a2 && Node.equal b1 b2) ||
        (Node.equal a1 b2 && Node.equal b1 a2)
      | Opaque u1, Opaque u2 -> Node.equal u1 u2
      | Bool _, _ | App_fun _, _ | App_ho _, _ | If _, _
      | Eq _, _ | Opaque _, _ | Not _, _
        -> false

    (** Hash compatible with [equal] (order-insensitive on [Eq]). *)
    let hash (s:t) : int =
      let module H = CCHash in
      match s with
      | Bool b -> H.combine2 10 (H.bool b)
      | App_fun (f, l) -> H.combine3 20 (Fun.hash f) (H.list Node.hash l)
      | App_ho (f, l) -> H.combine3 30 (Node.hash f) (H.list Node.hash l)
      | Eq (a,b) ->
        (* sort the two hashes so that [Eq (a,b)] and [Eq (b,a)] collide *)
        let ha = Node.hash a and hb = Node.hash b in
        if ha <= hb then H.combine3 40 ha hb else H.combine3 40 hb ha
      | Opaque u -> H.combine2 50 (Node.hash u)
      | If (a,b,c) -> H.combine4 60 (Node.hash a)(Node.hash b)(Node.hash c)
      | Not u -> H.combine2 70 (Node.hash u)

    let pp out = function
      | Bool b -> Fmt.bool out b
      | App_fun (f, []) -> Fun.pp out f
      | App_fun (f, l) -> Fmt.fprintf out "(@[%a@ %a@])" Fun.pp f (Util.pp_list Node.pp) l
      | App_ho (f, []) -> Node.pp out f
      | App_ho (f, l) -> Fmt.fprintf out "(@[%a@ %a@])" Node.pp f (Util.pp_list Node.pp) l
      | Opaque t -> Node.pp out t
      | Not u -> Fmt.fprintf out "(@[not@ %a@])" Node.pp u
      | Eq (a,b) -> Fmt.fprintf out "(@[=@ %a@ %a@])" Node.pp a Node.pp b
      | If (a,b,c) -> Fmt.fprintf out "(@[ite@ %a@ %a@ %a@])" Node.pp a Node.pp b Node.pp c
  end

  module Sig_tbl = CCHashtbl.Make(Signature)

  type t = {
    mutable ok: bool; (* [false] once a conflict has been found *)
    tbl: node T_tbl.t; (* term -> its node *)
    sig_tbl: node Sig_tbl.t; (* signature -> a node having it *)
    mutable combine: (node * node) list; (* pairs waiting to be merged *)
    mutable pending: node list; (* nodes whose signature must be recomputed *)
    true_: node;
    false_: node;
  }

  (** Empty closure; only the boolean constants are registered. *)
  let create tst : t =
    let true_ = T.bool tst true in
    let false_ = T.bool tst false in
    let self = {
      ok=true;
      tbl= T_tbl.create 128;
      sig_tbl=Sig_tbl.create 128;
      combine=[];
      pending=[];
      true_=Node.make true_;
      false_=Node.make false_;
    } in
    T_tbl.add self.tbl true_ self.true_;
    T_tbl.add self.tbl false_ self.false_;
    self

  (** Reset to the state right after [create]. The [true_]/[false_] nodes
      are reused, so their class links and parents must be reset too;
      otherwise [classes] would still report terms from before the reset. *)
  let clear (self:t) : unit =
    let {ok=_; tbl; sig_tbl; pending=_; combine=_; true_; false_} = self in
    self.ok <- true;
    self.pending <- [];
    self.combine <- [];
    T_tbl.clear tbl;
    Sig_tbl.clear sig_tbl;
    Node.reset_ true_;
    Node.reset_ false_;
    T_tbl.add tbl true_.n_t true_;
    T_tbl.add tbl false_.n_t false_;
    ()

  (** [sub_ t k] calls [k] on each immediate sub-term of [t], per [A.cc_view]. *)
  let sub_ t k : unit =
    match A.cc_view t with
    | Bool _ | Opaque _ -> ()
    | App_fun (_, args) -> args k
    | App_ho (f, args) -> k f; args k
    | Eq (a,b) -> k a; k b
    | Not u -> k u
    | If(a,b,c) -> k a; k b; k c

  (** Node for [t], creating it (and, recursively, nodes for its sub-terms)
      if needed. A new node is registered as parent of its sub-terms'
      roots and queued for signature computation. *)
  let rec add_t (self:t) (t:term) : node =
    match T_tbl.find self.tbl t with
    | n -> n
    | exception Not_found ->
      let node = Node.make t in
      T_tbl.add self.tbl t node;
      (* add sub-terms, and add [t] to their parent list *)
      sub_ t
        (fun u ->
           let n_u = Node.root @@ add_t self u in
           Node.add_parent n_u ~p:node);
      (* need to compute signature *)
      self.pending <- node :: self.pending;
      node

  (** Root of the class of [t]; fails via [Error.errorf] if [t] was never added. *)
  let find_t_ (self:t) (t:term): node =
    try T_tbl.find self.tbl t |> Node.root
    with Not_found -> Error.errorf "mini-cc.find_t: no node for %a" T.pp t

  exception E_unsat

  (** Current signature of [n], or [None] for leaves (booleans, opaque
      terms, constants) which are only equal to themselves by congruence. *)
  let compute_sig (self:t) (n:node) : Signature.t option =
    let[@inline] return x = Some x in
    match A.cc_view n.n_t with
    | Bool _ | Opaque _ -> None
    | Eq (a,b) ->
      let a = find_t_ self a in
      let b = find_t_ self b in
      return @@ Eq (a,b)
    | Not u -> return @@ Not (find_t_ self u)
    | App_fun (f, args) ->
      let args = args |> Iter.map (find_t_ self) |> Iter.to_list in
      if args<>[] then (
        return @@ App_fun (f, args)
      ) else None
    | App_ho (f, args) ->
      let args = args |> Iter.map (find_t_ self) |> Iter.to_list in
      return @@ App_ho (find_t_ self f, args)
    | If (a,b,c) ->
      return @@ If(find_t_ self a, find_t_ self b, find_t_ self c)

  (** Recompute [n]'s signature and queue the merges it implies:
      boolean simplifications of [=], [not] and [ite], then congruence
      with any other node sharing the same signature. *)
  let update_sig_ (self:t) (n: node) : unit =
    let push n2 = self.combine <- (n,n2) :: self.combine in
    match compute_sig self n with
    | None -> ()
    | Some (Eq (a,b)) when Node.equal a b ->
      (* reduce to [true] *)
      let n2 = self.true_ in
      Log.debugf 5
        (fun k->k "(@[mini-cc.congruence-by-eq@ %a@ %a@])" Node.pp n Node.pp n2);
      push n2
    | Some (Not u) when Node.equal u self.true_ -> push self.false_
    | Some (Not u) when Node.equal u self.false_ -> push self.true_
    | Some (If (a,b,_)) when Node.equal a self.true_ -> push b
    | Some (If (a,_,c)) when Node.equal a self.false_ -> push c
    | Some s ->
      (* includes equations whose sides are (not yet) equal: they must
         still be registered so that congruent equations get merged *)
      Log.debugf 5 (fun k->k "(@[mini-cc.update-sig@ %a@])" Signature.pp s);
      (match Sig_tbl.find self.sig_tbl s with
       | n2 when Node.equal n n2 -> ()
       | n2 ->
         (* collision, merge *)
         Log.debugf 5
           (fun k->k "(@[mini-cc.congruence-by-sig@ %a@ %a@])" Node.pp n Node.pp n2);
         push n2
       | exception Not_found ->
         Sig_tbl.add self.sig_tbl s n)

  let[@inline] is_bool self n = Node.equal self.true_ n || Node.equal self.false_ n

  (** Merge the classes of [n1] and [n2]. The new root is a boolean
      constant if either side is one, otherwise the larger class's root.
      Merging [true] with [false] sets [ok] to [false] and raises [E_unsat].
      Parents of the absorbed class are queued for signature update. *)
  let merge_ self n1 n2 : unit =
    let n1 = Node.root n1 in
    let n2 = Node.root n2 in
    if not @@ Node.equal n1 n2 then (
      (* merge into largest class, or into a boolean *)
      let n1, n2 =
        if is_bool self n1 then n1, n2
        else if is_bool self n2 then n2, n1
        else if Node.size n1 > Node.size n2 then n1, n2
        else n2, n1 in
      Log.debugf 5 (fun k->k "(@[mini-cc.merge@ :into %a@ %a@])" Node.pp n1 Node.pp n2);

      if is_bool self n1 && is_bool self n2 then (
        Log.debugf 5 (fun k->k "(mini-cc.conflict.merge-true-false)");
        self.ok <- false;
        raise E_unsat
      );

      self.pending <- List.rev_append n2.n_parents self.pending; (* will change signature *)

      (* merge parent lists *)
      n1.n_parents <- List.rev_append n2.n_parents n1.n_parents;
      n1.n_size <- n2.n_size + n1.n_size;

      (* update root pointer in [n2.class] *)
      Node.iter_cls n2 (fun n -> n.n_root <- n1);

      (* merge classes [next] pointers *)
      let n1_next = n1.n_next in
      n1.n_next <- n2.n_next;
      n2.n_next <- n1_next;
    )

  let[@inline] check_ok_ self =
    if not self.ok then raise_notrace E_unsat

  (** Process queued signature updates and merges until both queues are
      empty. Raises [E_unsat] on conflict, including when a conflict was
      found by an earlier call and nothing new has been queued since. *)
  let fixpoint (self:t) : unit =
    (* [update_sig_] only enqueues merges; [merge_] only enqueues updates *)
    let rec drain_pending () =
      match self.pending with
      | [] -> ()
      | n :: tl ->
        self.pending <- tl;
        update_sig_ self n;
        drain_pending ()
    in
    let rec drain_combine () =
      match self.combine with
      | [] -> ()
      | (n1,n2) :: tl ->
        self.combine <- tl;
        merge_ self n1 n2;
        drain_combine ()
    in
    let rec loop () =
      (* checked even with empty queues: a previous conflict may have
         left them empty while [ok] is already [false] *)
      check_ok_ self;
      match self.pending, self.combine with
      | [], [] -> ()
      | _ ->
        drain_pending ();
        drain_combine ();
        loop ()
    in
    loop ()

  (* API *)

  (** Assert literal [p] with polarity [sign]. A positive equation merges
      its two sides directly; any other literal is merged with [true] or
      [false]. Nothing is propagated until [check_sat]. *)
  let add_lit (self:t) (p:T.t) (sign:bool) : unit =
    match A.cc_view p with
    | Eq (t1,t2) when sign ->
      let n1 = add_t self t1 in
      let n2 = add_t self t2 in
      self.combine <- (n1,n2) :: self.combine
    | _ ->
      (* just merge with true/false *)
      let n = add_t self p in
      let n2 = if sign then self.true_ else self.false_ in
      self.combine <- (n,n2) :: self.combine

  (** [true] iff no conflict was found after running to fixpoint. Once
      [false] is returned, it keeps being returned until [clear]. *)
  let check_sat (self:t) : bool =
    try fixpoint self; true
    with E_unsat ->
      self.ok <- false;
      false

  (** Iterate over the current classes (including those of [true] and
      [false]); each class is an iterator over its terms. *)
  let classes self : _ Iter.t =
    T_tbl.values self.tbl
    |> Iter.filter Node.is_root
    |> Iter.map
      (fun n -> Node.iter_cls n |> Iter.map Node.term)
end
|