diff --git a/src/smt/Mini_cc.ml b/src/smt/Mini_cc.ml index 1882a686..80250408 100644 --- a/src/smt/Mini_cc.ml +++ b/src/smt/Mini_cc.ml @@ -1,9 +1,9 @@ module H = CCHash -type ('f, 't) view = ('f, 't) Mini_cc_intf.view = +type ('f, 't, 'ts) view = ('f, 't, 'ts) Mini_cc_intf.view = | Bool of bool - | App of 'f * 't list + | App of 'f * 'ts | If of 't * 't * 't type res = Mini_cc_intf.res = @@ -29,7 +29,7 @@ module Make(A: ARG) = struct mutable n_root: node; } - type signature = (fun_, node) view + type signature = (fun_, node, node list) view module Node = struct type t = node @@ -106,8 +106,8 @@ module Make(A: ARG) = struct let sub_ t k : unit = match T.view t with - | Bool _ | App (_, []) -> () - | App (_, l) -> List.iter k l + | Bool _ -> () + | App (_, args) -> args k | If(a,b,c) -> k a; k b; k c let rec add_t (self:t) (t:term) : node = @@ -166,8 +166,10 @@ module Make(A: ARG) = struct Sig_tbl.add self.sig_tbl s n in match T.view n.n_t with - | Bool _ | App (_, []) -> () - | App (f, l) -> aux @@ App (f, List.map (find_t_ self) l) + | Bool _ -> () + | App (f, args) -> + let args = args |> Sequence.map (find_t_ self) |> Sequence.to_list in + aux @@ App (f, args) | If (a,b,c) -> aux @@ If(find_t_ self a, find_t_ self b, find_t_ self c) (* merge the two classes *) diff --git a/src/smt/Mini_cc.mli b/src/smt/Mini_cc.mli index a7ee562d..69359b30 100644 --- a/src/smt/Mini_cc.mli +++ b/src/smt/Mini_cc.mli @@ -1,9 +1,9 @@ (** {1 Mini congruence closure} *) -type ('f, 't) view = ('f, 't) Mini_cc_intf.view = +type ('f, 't, 'ts) view = ('f, 't, 'ts) Mini_cc_intf.view = | Bool of bool - | App of 'f * 't list + | App of 'f * 'ts | If of 't * 't * 't type res = Mini_cc_intf.res = diff --git a/src/smt/Mini_cc_intf.ml b/src/smt/Mini_cc_intf.ml index 1a3d2832..52dff3c7 100644 --- a/src/smt/Mini_cc_intf.ml +++ b/src/smt/Mini_cc_intf.ml @@ -1,9 +1,13 @@ -type ('f, 't) view = +type ('f, 't, 'ts) view = | Bool of bool - | App of 'f * 't list + | App of 'f * 'ts | If of 't * 't * 't +(* TODO: also HO app, Eq, Distinct cases? + -> then API that just adds boolean terms and does the right thing in case of + Eq/Distinct *) + type res = | Sat | Unsat @@ -20,8 +24,10 @@ module type ARG = sig type t val equal : t -> t -> bool val hash : t -> int - val view : t -> (Fun.t, t) view val pp : t Fmt.printer + + (** View the term through the lens of the congruence closure *) + val view : t -> (Fun.t, t, t Sequence.t) view end end diff --git a/src/smt/Term.ml b/src/smt/Term.ml index 4dafbb5a..a4ebead7 100644 --- a/src/smt/Term.ml +++ b/src/smt/Term.ml @@ -86,6 +86,13 @@ let[@inline] is_const t = match view t with | App_cst (_, a) -> IArray.is_empty a | _ -> false +let cc_view (t:t) = + let module C = Mini_cc in + match view t with + | Bool b -> C.Bool b + | App_cst (f,args) -> C.App (f, IArray.to_seq args) + | If (a,b,c) -> C.If (a,b,c) + module As_key = struct type t = term let compare = compare @@ -113,3 +120,4 @@ let as_cst_undef (t:term): (cst * Ty.Fun.t) option = | _ -> None let pp = Solver_types.pp_term + diff --git a/src/smt/Term.mli b/src/smt/Term.mli index a8a4feeb..38dc9a98 100644 --- a/src/smt/Term.mli +++ b/src/smt/Term.mli @@ -49,6 +49,8 @@ val is_true : t -> bool val is_false : t -> bool val is_const : t -> bool +val cc_view : t -> (cst,t,t Sequence.t) Mini_cc.view + (* return [Some] iff the term is an undefined constant *) val as_cst_undef : t -> (cst * Ty.Fun.t) option