From 1d5c1c187c81666371eb19f3b271db7792b9648d Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Thu, 25 Jan 2018 23:32:36 -0600 Subject: [PATCH] wip: basic SMT infrastructure - basic types, including terms and nodes (internalized terms) - congruence closure - utils --- src/smt/Bag.ml | 67 ++++ src/smt/Bag.mli | 31 ++ src/smt/Congruence_closure.ml | 560 +++++++++++++++++++++++++++++++++ src/smt/Congruence_closure.mli | 66 ++++ src/smt/Cst.ml | 60 ++++ src/smt/Cst.mli | 22 ++ src/smt/Equiv_class.ml | 86 +++++ src/smt/Equiv_class.mli | 66 ++++ src/smt/Hash.ml | 39 +++ src/smt/Hash.mli | 24 ++ src/smt/IArray.ml | 155 +++++++++ src/smt/IArray.mli | 98 ++++++ src/smt/ID.ml | 40 +++ src/smt/ID.mli | 26 ++ src/smt/Intf.ml | 24 ++ src/smt/Lit.ml | 54 ++++ src/smt/Lit.mli | 29 ++ src/smt/Solver_types.ml | 274 ++++++++++++++++ src/smt/Term.ml | 194 ++++++++++++ src/smt/Term.mli | 74 +++++ src/smt/Term_cell.ml | 141 +++++++++ src/smt/Term_cell.mli | 38 +++ src/smt/Ty.ml | 88 ++++++ src/smt/Ty.mli | 32 ++ src/smt/Ty_card.ml | 19 ++ src/smt/Ty_card.mli | 19 ++ src/smt/Util.ml | 28 ++ src/smt/Util.mli | 17 + src/smt/jbuild | 9 + 29 files changed, 2380 insertions(+) create mode 100644 src/smt/Bag.ml create mode 100644 src/smt/Bag.mli create mode 100644 src/smt/Congruence_closure.ml create mode 100644 src/smt/Congruence_closure.mli create mode 100644 src/smt/Cst.ml create mode 100644 src/smt/Cst.mli create mode 100644 src/smt/Equiv_class.ml create mode 100644 src/smt/Equiv_class.mli create mode 100644 src/smt/Hash.ml create mode 100644 src/smt/Hash.mli create mode 100644 src/smt/IArray.ml create mode 100644 src/smt/IArray.mli create mode 100644 src/smt/ID.ml create mode 100644 src/smt/ID.mli create mode 100644 src/smt/Intf.ml create mode 100644 src/smt/Lit.ml create mode 100644 src/smt/Lit.mli create mode 100644 src/smt/Solver_types.ml create mode 100644 src/smt/Term.ml create mode 100644 src/smt/Term.mli create mode 100644 src/smt/Term_cell.ml create mode 100644 src/smt/Term_cell.mli create mode 100644 src/smt/Ty.ml create mode 100644 src/smt/Ty.mli create mode 100644 src/smt/Ty_card.ml create mode 100644 src/smt/Ty_card.mli create mode 100644 src/smt/Util.ml create mode 100644 src/smt/Util.mli create mode 100644 src/smt/jbuild diff --git a/src/smt/Bag.ml b/src/smt/Bag.ml new file mode 100644 index 00000000..bc5f41a3 --- /dev/null +++ b/src/smt/Bag.ml @@ -0,0 +1,67 @@ + +(* This file is free software. See file "license" for more details. *) + +(** {1 Ordered Bag of Elements} + + A data structure where we can have duplicate elements, optimized for + fast concatenation and size. *) + +type 'a t = + | E + | L of 'a + | N of 'a t * 'a t * int (* size *) + +let empty = E + +let is_empty = function + | E -> true + | L _ | N _ -> false + +let size = function + | E -> 0 + | L _ -> 1 + | N (_,_,sz) -> sz + +let return x = L x + +let append a b = match a, b with + | E, _ -> b + | _, E -> a + | _ -> N (a, b, size a + size b) + +let cons x t = match t with + | E -> L x + | L _ -> N (L x, t, 2) + | N (_,_,sz) -> N (L x, t, sz+1) + +let rec fold f acc = function + | E -> acc + | L x -> f acc x + | N (a,b,_) -> fold f (fold f acc a) b + +let rec to_seq t yield = match t with + | E -> () + | L x -> yield x + | N (a,b,_) -> to_seq a yield; to_seq b yield + +let iter f t = to_seq t f + +let equal f a b = + let rec push x l = match x with + | E -> l + | L _ -> x :: l + | N (a,b,_) -> push a (b::l) + in + (* same-fringe traversal, using two stacks *) + let rec aux la lb = match la, lb with + | [], [] -> true + | E::_, _ | _, E::_ -> assert false + | N (x,y,_)::la, _ -> aux (push x (y::la)) lb + | _, N(x,y,_)::lb -> aux la (push x (y::lb)) + | L x :: la, L y :: lb -> f x y && aux la lb + | [], L _::_ + | L _::_, [] -> false + in + size a = size b && + aux (push a []) (push b []) + diff --git a/src/smt/Bag.mli b/src/smt/Bag.mli new file mode 100644 index 00000000..ad3585b8 --- /dev/null +++ b/src/smt/Bag.mli @@ -0,0 +1,31 @@ + +(* This file is free software. See file "license" for more details. *) + +(** {1 Ordered Bag of Elements} + + A data structure where we can have duplicate elements, optimized for + fast concatenation and size. *) + +type +'a t + +val empty : 'a t + +val is_empty : _ t -> bool + +val return : 'a -> 'a t + +val size : _ t -> int +(** Constant time *) + +val cons : 'a -> 'a t -> 'a t + +val append : 'a t -> 'a t -> 'a t + +val to_seq : 'a t -> 'a Sequence.t + +val fold : ('a -> 'b -> 'a) -> 'a -> 'b t -> 'a + +val iter : ('a -> unit) -> 'a t -> unit + +val equal : ('a -> 'a -> bool) -> 'a t -> 'a t -> bool +(** Are the two bags equal, element wise? *) diff --git a/src/smt/Congruence_closure.ml b/src/smt/Congruence_closure.ml new file mode 100644 index 00000000..0dcde83f --- /dev/null +++ b/src/smt/Congruence_closure.ml @@ -0,0 +1,560 @@ + +open CDCL +open Solver_types + +type node = Equiv_class.t +type repr = Equiv_class.repr + +(** A signature is a shallow term shape where immediate subterms + are representative *) +module Signature = struct + type t = node term_cell + include Term_cell.Make_eq(Equiv_class) +end + +module Sig_tbl = CCHashtbl.Make(Signature) + +type merge_op = node * node * cc_explanation +(* a merge operation to perform *) + +type actions = + | Propagate of Lit.t * cc_explanation list + | Split of Lit.t list * cc_explanation list + | Merge of node * node (* merge these two classes *) + +type t = { + tst: Term.state; + tbl: node Term.Tbl.t; + (* internalization [term -> node] *) + signatures_tbl : repr Sig_tbl.t; + (* map a signature to the corresponding term in some equivalence class. + A signature is a [term_cell] in which every immediate subterm + that participates in the congruence/evaluation relation + is normalized (i.e. is its own representative). + The critical property is that all members of an equivalence class + that have the same "shape" (including head symbol) + have the same signature *) + on_backtrack: (unit -> unit) -> unit; + (* register a function to be called when we backtrack *) + at_lvl_0: unit -> bool; + (* currently at level 0? *) + on_merge: (repr -> repr -> cc_explanation -> unit) list; + (* callbacks to call when we merge classes *) + pending: node Vec.t; + (* nodes to check, maybe their new signature is in {!signatures_tbl} *) + combine: merge_op Vec.t; + (* pairs of terms to merge *) + mutable actions : actions list; + (* some boolean propagations/splits to make. *) + mutable ps_lits: Lit.Set.t; + (* proof state *) + ps_queue: (node*node) Vec.t; + (* pairs to explain *) + true_ : node lazy_t; + false_ : node lazy_t; +} +(* TODO: an additional union-find to keep track, for each term, + of the terms they are known to be equal to, according + to the current explanation. That allows not to prove some equality + several times. + See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *) + +module CC_expl_set = CCSet.Make(struct + type t = cc_explanation + let compare = Solver_types.cmp_cc_expl + end) + +let[@inline] is_root_ (n:node) : bool = n.n_root == n + +let[@inline] size_ (r:repr) = + assert (is_root_ (r:>node)); + Bag.size (r :> node).n_parents + +(* check if [t] is in the congruence closure. + Invariant: [in_cc t => in_cc u, forall u subterm t] *) +let[@inline] mem (cc:t) (t:term): bool = + Term.Tbl.mem cc.tbl t + +(* find representative, recursively, and perform path compression *) +let rec find_rec cc (n:node) : repr = + if n==n.n_root then ( + Equiv_class.unsafe_repr_of_node n + ) else ( + let old_root = n.n_root in + let root = find_rec cc old_root in + (* path compression *) + if (root :> node) != old_root then ( + if not (cc.at_lvl_0 ()) then ( + cc.on_backtrack (fun () -> n.n_root <- old_root); + ); + n.n_root <- (root :> node); + ); + root + ) + +let[@inline] true_ cc = Lazy.force cc.true_ +let[@inline] false_ cc = Lazy.force cc.false_ + +(* get term that should be there *) +let[@inline] get_ cc (t:term) : node = + try Term.Tbl.find cc.tbl t + with Not_found -> + Log.debugf 5 (fun k->k "(@[missing@ %a@])" Term.pp t); + assert false + +(* non-recursive, inlinable function for [find] *) +let[@inline] find st (n:node) : repr = + if n == n.n_root + then (Equiv_class.unsafe_repr_of_node n) + else find_rec st n + +let[@inline] find_tn cc (t:term) : repr = get_ cc t |> find cc +let[@inline] find_tt cc (t:term) : term = find_tn cc t |> Equiv_class.Repr.term + +let[@inline] same_class cc (n1:node)(n2:node): bool = + Equiv_class.Repr.equal (find cc n1) (find cc n2) + +(* compute signature *) +let signature cc (t:term): node term_cell option = + let find = (find_tn cc :> term -> node) in + begin match Term.cell t with + | True | Builtin _ + -> None + | App_cst (_, a) when IArray.is_empty a -> None + | App_cst (f, a) -> App_cst (f, IArray.map find a) |> CCOpt.return + | If (a,b,c) -> If (find a, get_ cc b, get_ cc c) |> CCOpt.return + | Case (t, m) -> Case (find t, ID.Map.map (get_ cc) m) |> CCOpt.return + end + +(* find whether the given (parent) term corresponds to some signature + in [signatures_] *) +let find_by_signature cc (t:term) : repr option = match signature cc t with + | None -> None + | Some s -> Sig_tbl.get cc.signatures_tbl s + +let remove_signature cc (t:term): unit = match signature cc t with + | None -> () + | Some s -> + Sig_tbl.remove cc.signatures_tbl s + +let add_signature cc (t:term) (r:repr): unit = match signature cc t with + | None -> () + | Some s -> + assert (CCOpt.map_or ~default:false (Signature.equal s) + (signature cc (r:>node).n_term)); + (* add, but only if not present already *) + begin match Sig_tbl.get cc.signatures_tbl s with + | None -> + if not (cc.at_lvl_0 ()) then ( + cc.on_backtrack + (fun () -> Sig_tbl.remove cc.signatures_tbl s); + ); + Sig_tbl.add cc.signatures_tbl s r; + | Some r' -> + assert (Equiv_class.Repr.equal r r'); + end + +let is_done (cc:t): bool = + Vec.is_empty cc.pending && + Vec.is_empty cc.combine + +let push_pending cc t : unit = + Log.debugf 5 (fun k->k "(@[push_pending@ %a@])" Equiv_class.pp t); + Vec.push cc.pending t + +let push_combine cc t u e : unit = + Log.debugf 5 + (fun k->k "(@[push_combine@ %a@ %a@ expl: %a@])" + Equiv_class.pp t Equiv_class.pp u pp_cc_explanation e); + Vec.push cc.combine (t,u,e) + +let push_split cc (lits:lit list) (expl:cc_explanation list): unit = + Log.debugf 5 + (fun k->k "(@[push_split@ (@[%a@])@ expl: (@[%a@])@])" + (Util.pp_list Lit.pp) lits (Util.pp_list pp_cc_explanation) expl); + let l = Split (lits, expl) in + cc.actions <- l :: cc.actions + +let push_propagation cc (lit:lit) (expl:cc_explanation list): unit = + Log.debugf 5 + (fun k->k "(@[push_propagate@ %a@ expl: (@[%a@])@])" + Lit.pp lit (Util.pp_list pp_cc_explanation) expl); + let l = Propagate (lit,expl) in + cc.actions <- l :: cc.actions + +let[@inline] union cc (a:node) (b:node) (e:cc_explanation): unit = + if not (same_class cc a b) then ( + push_combine cc a b e; (* start by merging [a=b] *) + ) + +(* re-root the explanation tree of the equivalence class of [n] + so that it points to [n]. + postcondition: [n.n_expl = None] *) +let rec reroot_expl cc (n:node): unit = + let old_expl = n.n_expl in + if not (cc.at_lvl_0 ()) then ( + cc.on_backtrack (fun () -> n.n_expl <- old_expl); + ); + begin match old_expl with + | None -> () (* already root *) + | Some (u, e_n_u) -> + reroot_expl cc u; + u.n_expl <- Some (n, e_n_u); + n.n_expl <- None; + end + +(* TODO: + - move what follows into {!Theory}. + - also, obtain merges of CC via callbacks / [pop_merges] afterwards? + *) + +exception Exn_unsat of cc_explanation list + +let unsat (e:cc_explanation list): _ = raise (Exn_unsat e) + +type result = + | Sat of actions list + | Unsat of cc_explanation list + (* list of direct explanations to the conflict. *) + +let[@inline] all_classes cc : repr Sequence.t = + Term.Tbl.values cc.tbl + |> Sequence.filter is_root_ + |> Equiv_class.unsafe_repr_seq_of_seq + +(* main CC algo: add terms from [pending] to the signature table, + check for collisions *) +let rec update_pending (cc:t): result = + (* step 2 deal with pending (parent) terms whose equiv class + might have changed *) + while not (Vec.is_empty cc.pending) do + let n = Vec.pop_last cc.pending in + (* check if some parent collided *) + begin match find_by_signature cc n.n_term with + | None -> + (* add to the signature table [n --> n.root] *) + add_signature cc n.n_term (find cc n) + | Some u -> + (* must combine [t] with [r] *) + push_combine cc n (u:>node) (CC_congruence (n,(u:>node))) + end; + (* FIXME: when to actually evaluate? + eval_pending cc; + *) + done; + if is_done cc then ( + let actions = cc.actions in + cc.actions <- []; + Sat actions + ) else ( + update_combine cc (* repeat *) + ) + +(* main CC algo: merge equivalence classes in [st.combine]. + @raise Exn_unsat if merge fails *) +and update_combine cc = + while not (Vec.is_empty cc.combine) do + let a, b, e_ab = Vec.pop_last cc.combine in + let ra = find cc a in + let rb = find cc b in + if not (Equiv_class.Repr.equal ra rb) then ( + assert (is_root_ (ra:>node)); + assert (is_root_ (rb:>node)); + (* We will merge [r_from] into [r_into]. + we try to ensure that [size ra <= size rb] in general, unless + it clashes with the invariant that the representative must + be a normal form if the class contains a normal form *) + let r_from, r_into = + if size_ ra > size_ rb then rb, ra else ra, rb + in + (* remove [ra.parents] from signature, put them into [st.pending] *) + begin + Bag.to_seq (r_from:>node).n_parents + |> Sequence.iter + (fun parent -> + (* FIXME: with OCaml's hashtable, we should be able + to keep this entry (and have it become relevant later + once the signature of [parent] is backtracked) *) + remove_signature cc parent.n_term; + push_pending cc parent) + end; + (* perform [union ra rb] *) + begin + let r_from = (r_from :> node) in + let r_into = (r_into :> node) in + let rb_old_class = r_into.n_class in + let rb_old_parents = r_into.n_parents in + cc.on_backtrack + (fun () -> + r_from.n_root <- r_from; + r_into.n_class <- rb_old_class; + r_into.n_parents <- rb_old_parents); + r_from.n_root <- r_into; + r_from.n_class <- Bag.append rb_old_class r_from.n_class; + r_from.n_parents <- Bag.append rb_old_parents r_from.n_parents; + end; + (* update explanations (a -> b), arbitrarily *) + begin + reroot_expl cc a; + assert (a.n_expl = None); + if not (cc.at_lvl_0 ()) then ( + cc.on_backtrack (fun () -> a.n_expl <- None); + ); + a.n_expl <- Some (b, e_ab); + end; + (* notify listeners of the merge *) + notify_merge cc r_from ~into:r_into e_ab; + ) + done; + (* now update pending terms again *) + update_pending cc + +(* Checks if [ra] and [~into] have compatible normal forms and can + be merged w.r.t. the theories. + Side effect: also pushes sub-tasks *) +and notify_merge cc (ra:repr) ~into:(rb:repr) (e:cc_explanation): unit = + assert (is_root_ (ra:>node)); + assert (is_root_ (rb:>node)); + List.iter + (fun f -> f ra rb e) + cc.on_merge; + () + + +(* FIXME: callback? +(* evaluation rules: if, case... *) +and eval_pending (t:term): unit = + List.iter + (fun ((module Theory):repr theory) -> Theory.eval t) + theories + *) + +(* FIXME: remove? +(* main CC algo: add missing terms to the congruence class *) +and update_add (cc:t) terms () = + while not (Queue.is_empty cc.terms_to_add) do + let t = Queue.pop cc.terms_to_add in + add cc t + done +*) + +(* add [t] to [cc] when not present already *) +and add_new_term cc (t:term) : node = + assert (not @@ mem cc t); + let n = Equiv_class.make t in + (* how to add a subterm *) + let add_to_parents_of_sub_node (sub:node) : unit = + let old_parents = sub.n_parents in + if not @@ cc.at_lvl_0 () then ( + cc.on_backtrack (fun () -> sub.n_parents <- old_parents); + ); + sub.n_parents <- Bag.cons n sub.n_parents; + push_pending cc sub + in + (* add sub-term to [cc], and register [n] to its parents *) + let add_sub_t (u:term) : unit = + let n_u = add cc u in + add_to_parents_of_sub_node n_u + in + (* register sub-terms, add [t] to their parent list *) + begin match t.term_cell with + | True -> () + | App_cst (_, a) -> IArray.iter add_sub_t a + | If (a,b,c) -> + add_sub_t a; + add_sub_t b; + add_sub_t c + | Case (u, _) -> add_sub_t u + | Builtin b -> Term.builtin_to_seq b add_sub_t + end; + (* remove term when we backtrack *) + if not (cc.at_lvl_0 ()) then ( + cc.on_backtrack (fun () -> Term.Tbl.remove cc.tbl t); + ); + (* add term to the table *) + Term.Tbl.add cc.tbl t n; + (* [n] might be merged with other equiv classes *) + push_pending cc n; + n + +(* add [t=u] to the congruence closure, unconditionally (reduction relation) *) +and[@inline] add_eqn (cc:t) (eqn:merge_op): unit = + let t, u, expl = eqn in + push_combine cc t u expl + +(* add a term *) +and[@inline] add cc t = + try Term.Tbl.find cc.tbl t + with Not_found -> add_new_term cc t + +let[@inline] add_seq cc seq = seq (fun t -> ignore @@ add cc t) + +(* assert that this boolean literal holds *) +let assert_lit cc lit : unit = match Lit.view lit with + | Lit_fresh _ + | Lit_expanded _ -> () + | Lit_atom t -> + assert (Ty.is_prop t.term_ty); + let sign = Lit.sign lit in + (* equate t and true/false *) + let rhs = if sign then true_ cc else false_ cc in + let n = add cc t in + push_combine cc n rhs (CC_lit lit); + () + +let create ?(size=2048) ~on_backtrack ~at_lvl_0 ~on_merge (tst:Term.state) : t = + assert (at_lvl_0 ()); + let nd = Equiv_class.dummy in + let rec cc = { + tst; + tbl = Term.Tbl.create size; + on_merge; + signatures_tbl = Sig_tbl.create size; + on_backtrack; + at_lvl_0; + pending=Vec.make_empty Equiv_class.dummy; + combine= Vec.make_empty (nd,nd,CC_reduce_eq(nd,nd)); + actions=[]; + ps_lits=Lit.Set.empty; + ps_queue=Vec.make_empty (nd,nd); + true_ = lazy (add cc (Term.true_ tst)); + false_ = lazy (add cc (Term.false_ tst)); + } in + ignore (Lazy.force cc.true_); + ignore (Lazy.force cc.false_); + cc + +(* distance from [t] to its root in the proof forest *) +let[@inline][@unroll 2] rec distance_to_root (n:node): int = match n.n_expl with + | None -> 0 + | Some (t', _) -> 1 + distance_to_root t' + +(* find the closest common ancestor of [a] and [b] in the proof forest *) +let find_common_ancestor (a:node) (b:node) : node = + let d_a = distance_to_root a in + let d_b = distance_to_root b in + (* drop [n] nodes in the path from [t] to its root *) + let rec drop_ n t = + if n=0 then t + else match t.n_expl with + | None -> assert false + | Some (t', _) -> drop_ (n-1) t' + in + (* reduce to the problem where [a] and [b] have the same distance to root *) + let a, b = + if d_a > d_b then drop_ (d_a-d_b) a, b + else if d_a < d_b then a, drop_ (d_b-d_a) b + else a, b + in + (* traverse stepwise until a==b *) + let rec aux_same_dist a b = + if a==b then a + else match a.n_expl, b.n_expl with + | None, _ | _, None -> assert false + | Some (a', _), Some (b', _) -> aux_same_dist a' b' + in + aux_same_dist a b + +let[@inline] ps_add_obligation (cc:t) a b = Vec.push cc.ps_queue (a,b) +let[@inline] ps_add_lit ps l = ps.ps_lits <- Lit.Set.add l ps.ps_lits +let[@inline] ps_add_expl ps e = match e with + | CC_lit lit -> ps_add_lit ps lit + | CC_reduce_eq _ | CC_congruence _ + | CC_injectivity _ | CC_reduction + -> () + +and ps_add_obligation_t cc (t1:term) (t2:term) = + let n1 = get_ cc t1 in + let n2 = get_ cc t2 in + ps_add_obligation cc n1 n2 + +let ps_clear (cc:t) = + cc.ps_lits <- Lit.Set.empty; + Vec.clear cc.ps_queue; + () + +let decompose_explain cc (e:cc_explanation): unit = + Log.debugf 5 (fun k->k "(@[decompose_expl@ %a@])" pp_cc_explanation e); + ps_add_expl cc e; + begin match e with + | CC_reduction + | CC_lit _ -> () + | CC_reduce_eq (a, b) -> + ps_add_obligation cc a b; + | CC_injectivity (t1,t2) + (* FIXME: should this be different from CC_congruence? just explain why t1==t2? *) + | CC_congruence (t1,t2) -> + begin match t1.n_term.term_cell, t2.n_term.term_cell with + | True, _ -> assert false (* no congruence here *) + | App_cst (f1, a1), App_cst (f2, a2) -> + assert (Cst.equal f1 f2); + assert (IArray.length a1 = IArray.length a2); + IArray.iter2 (ps_add_obligation_t cc) a1 a2 + | Case (_t1, _m1), Case (_t2, _m2) -> + assert false + (* TODO: this should never happen + ps_add_obligation ps t1 t2; + ID.Map.iter + (fun id rhs1 -> + let rhs2 = ID.Map.find id m2 in + ps_add_obligation ps rhs1 rhs2) + m1; + *) + | If (a1,b1,c1), If (a2,b2,c2) -> + ps_add_obligation_t cc a1 a2; + ps_add_obligation_t cc b1 b2; + ps_add_obligation_t cc c1 c2; + | Builtin _, _ -> assert false + | App_cst _, _ + | Case _, _ + | If _, _ + -> assert false + end + end + +(* explain why [a = parent_a], where [a -> ... -> parent_a] in the + proof forest *) +let rec explain_along_path ps (a:node) (parent_a:node) : unit = + if a!=parent_a then ( + match a.n_expl with + | None -> assert false + | Some (next_a, e_a_b) -> + decompose_explain ps e_a_b; + (* now prove [next_a = parent_a] *) + explain_along_path ps next_a parent_a + ) + +(* find explanation *) +let explain_loop (cc : t) : Lit.Set.t = + while not (Vec.is_empty cc.ps_queue) do + let a, b = Vec.pop_last cc.ps_queue in + Log.debugf 5 + (fun k->k "(@[explain_loop at@ %a@ %a@])" Equiv_class.pp a Equiv_class.pp b); + assert (Equiv_class.Repr.equal (find cc a) (find cc b)); + let c = find_common_ancestor a b in + explain_along_path cc a c; + explain_along_path cc b c; + done; + cc.ps_lits + +let explain_unfold cc (l:cc_explanation list): Lit.Set.t = + Log.debugf 5 + (fun k->k "(@[explain_confict@ (@[%a@])@])" + (Util.pp_list pp_cc_explanation) l); + ps_clear cc; + List.iter (decompose_explain cc) l; + explain_loop cc + +let check_ cc = + try update_pending cc + with Exn_unsat e -> + Unsat e + +(* check satisfiability, update congruence closure *) +let check (cc:t) : result = + Log.debug 5 "(cc.check)"; + check_ cc + +let final_check cc : result = + Log.debug 5 "(CC.final_check)"; + check_ cc diff --git a/src/smt/Congruence_closure.mli b/src/smt/Congruence_closure.mli new file mode 100644 index 00000000..fc864895 --- /dev/null +++ b/src/smt/Congruence_closure.mli @@ -0,0 +1,66 @@ +(** {2 Congruence Closure} *) + +open CDCL +open Solver_types + +type t +(** Global state of the congruence closure *) + +type node = Equiv_class.t +(** Node in the congruence closure *) + +type repr = Equiv_class.repr +(** Node that is currently a representative *) + +val create : + ?size:int -> + on_backtrack:((unit -> unit) -> unit) -> + at_lvl_0:(unit -> bool) -> + on_merge:(repr -> repr -> cc_explanation -> unit) list -> + Term.state -> + t +(** Create a new congruence closure. + @param on_backtrack used to register undo actions + @param on_merge callbacks called when two equiv classes are merged +*) + +val find : t -> node -> repr +(** Current representative *) + +val same_class : t -> node -> node -> bool +(** Are these two classes the same in the current CC? *) + +val union : t -> node -> node -> cc_explanation -> unit +(** Merge the two equivalence classes. Will be undone on backtracking. *) + +val assert_lit : t -> Lit.t -> unit +(** Given a literal, assume it in the congruence closure and propagate + its consequences. Will be backtracked. *) + +val mem : t -> term -> bool +(** Is the term properly added to the congruence closure? *) + +val add : t -> term -> node +(** Add the term to the congruence closure, if not present already. + Will be backtracked. *) + +val add_seq : t -> term Sequence.t -> unit +(** Add a sequence of terms to the congruence closure *) + +type actions = + | Propagate of Lit.t * cc_explanation list + | Split of Lit.t list * cc_explanation list + | Merge of node * node (* merge these two classes *) + +type result = + | Sat of actions list + | Unsat of cc_explanation list + (* list of direct explanations to the conflict. *) + +val check : t -> result + +val final_check : t -> result + +val explain_unfold: t -> cc_explanation list -> Lit.Set.t +(** Unfold those explanations into a complete set of + literals implying them *) diff --git a/src/smt/Cst.ml b/src/smt/Cst.ml new file mode 100644 index 00000000..5f1f3856 --- /dev/null +++ b/src/smt/Cst.ml @@ -0,0 +1,60 @@ + +open CDCL +open Solver_types + +type t = cst + +let id t = t.cst_id + +let ty_of_kind = function + | Cst_defined (ty, _, _) + | Cst_undef ty + | Cst_test (ty, _) + | Cst_proj (ty, _, _) -> ty + | Cst_cstor (lazy cstor) -> cstor.cstor_ty + +let ty t = ty_of_kind t.cst_kind + +let arity t = fst (Ty.unfold_n (ty t)) + +let make cst_id cst_kind = {cst_id; cst_kind} +let make_cstor id ty cstor = + let _, ret = Ty.unfold ty in + assert (Ty.is_data ret); + make id (Cst_cstor cstor) +let make_proj id ty cstor i = + make id (Cst_proj (ty, cstor, i)) +let make_tester id ty cstor = + make id (Cst_test (ty, cstor)) + +let make_defined id ty t info = make id (Cst_defined (ty, t, info)) + +let make_undef id ty = make id (Cst_undef ty) + +let as_undefined (c:t) = match c.cst_kind with + | Cst_undef ty -> Some (c,ty) + | Cst_defined _ | Cst_cstor _ | Cst_proj _ | Cst_test _ + -> None + +let as_undefined_exn (c:t) = match as_undefined c with + | Some tup -> tup + | None -> assert false + +let is_finite_cstor c = match c.cst_kind with + | Cst_cstor (lazy {cstor_card=Finite; _}) -> true + | _ -> false + +let equal a b = ID.equal a.cst_id b.cst_id +let compare a b = ID.compare a.cst_id b.cst_id +let hash t = ID.hash t.cst_id +let pp out a = ID.pp out a.cst_id + +module Map = CCMap.Make(struct + type t = cst + let compare = compare + end) +module Tbl = CCHashtbl.Make(struct + type t = cst + let equal = equal + let hash = hash + end) diff --git a/src/smt/Cst.mli b/src/smt/Cst.mli new file mode 100644 index 00000000..b0e234f3 --- /dev/null +++ b/src/smt/Cst.mli @@ -0,0 +1,22 @@ + +open CDCL +open Solver_types + +type t = cst +val id : t -> ID.t +val ty : t -> Ty.t +val make_cstor : ID.t -> Ty.t -> data_cstor lazy_t -> t +val make_proj : ID.t -> Ty.t -> data_cstor lazy_t -> int -> t +val make_tester : ID.t -> Ty.t -> data_cstor lazy_t -> t +val make_defined : ID.t -> Ty.t -> term lazy_t -> cst_defined_info -> t +val make_undef : ID.t -> Ty.t -> t +val arity : t -> int (* number of args *) +val equal : t -> t -> bool +val compare : t -> t -> int +val hash : t -> int +val as_undefined : t -> (t * Ty.t) option +val as_undefined_exn : t -> t * Ty.t +val is_finite_cstor : t -> bool +val pp : t Fmt.printer +module Map : CCMap.S with type key = t +module Tbl : CCHashtbl.S with type key = t diff --git a/src/smt/Equiv_class.ml b/src/smt/Equiv_class.ml new file mode 100644 index 00000000..74fa1f7a --- /dev/null +++ b/src/smt/Equiv_class.ml @@ -0,0 +1,86 @@ + +open CDCL +open Solver_types + +type t = cc_node +type repr = t +type payload = cc_node_payload + +let field_expanded = Node_bits.mk_field () +let field_has_expansion_lit = Node_bits.mk_field () +let field_is_lit = Node_bits.mk_field () +let field_is_split = Node_bits.mk_field () +let field_add_level_0 = Node_bits.mk_field() +let () = Node_bits.freeze() + +let[@inline] equal (n1:t) n2 = n1==n2 +let[@inline] hash n = Term.hash n.n_term +let[@inline] term n = n.n_term +let[@inline] payload n = n.n_payload +let[@inline] pp out n = Term.pp out n.n_term + +module Repr = struct + type node = t + type t = repr + let equal = equal + let hash = hash + let term = term + let payload = payload + let pp = pp + + let[@inline] parents r = r.n_parents + let[@inline] class_ r = r.n_class +end + +let make (t:term) : t = + let rec n = { + n_term=t; + n_bits=Node_bits.empty; + n_class=Bag.empty; + n_parents=Bag.empty; + n_root=n; + n_expl=None; + n_payload=[]; + } in + (* set [class(t) = {t}] *) + n.n_class <- Bag.return n; + n + +let set_payload ?(can_erase=fun _->false) n e = + let rec aux = function + | [] -> [e] + | e' :: tail when can_erase e' -> e :: tail + | e' :: tail -> e' :: aux tail + in + n.n_payload <- aux n.n_payload + +let payload_find ~f:p n = + begin match n.n_payload with + | [] -> None + | e1 :: tail -> + match p e1, tail with + | Some _ as res, _ -> res + | None, [] -> None + | None, e2 :: tail2 -> + match p e2, tail2 with + | Some _ as res, _ -> res + | None, [] -> None + | None, e3 :: tail3 -> + match p e3 with + | Some _ as res -> res + | None -> CCList.find_map p tail3 + end + +let payload_pred ~f:p n = + begin match n.n_payload with + | [] -> false + | e :: _ when p e -> true + | _ :: e :: _ when p e -> true + | _ :: _ :: e :: _ when p e -> true + | l -> List.exists p l + end + + +let dummy = make Term.dummy +let[@inline] unsafe_repr_of_node n = n +let[@inline] unsafe_repr_seq_of_seq s = s diff --git a/src/smt/Equiv_class.mli b/src/smt/Equiv_class.mli new file mode 100644 index 00000000..cfae55cc --- /dev/null +++ b/src/smt/Equiv_class.mli @@ -0,0 +1,66 @@ + +open Solver_types + +type t = cc_node +type repr = private t +type payload = cc_node_payload + +val field_expanded : Node_bits.field +(** Term is expanded? *) + +val field_has_expansion_lit : Node_bits.field +(** Upon expansion, does this term have a special literal [Lit_expanded t] + that should be asserted? *) + +val field_is_lit : Node_bits.field +(** Is this term a boolean literal? *) + +val field_is_split : Node_bits.field +(** Did we perform case split (Split 1) on this term? + This is only relevant for terms whose type is a datatype. *) + +val field_add_level_0 : Node_bits.field +(** Is the corresponding term to be re-added upon backtracking, + down to level 0? *) + +(** {2 basics} *) + +val term : t -> term +val equal : t -> t -> bool +val hash : t -> int +val pp : t Fmt.printer +val payload : t -> payload list + +module Repr : sig + type node = t + type t = repr + + val term : t -> term + val equal : t -> t -> bool + val hash : t -> int + val pp : t Fmt.printer + val payload : t -> payload list + + val parents : t -> node Bag.t + val class_ : t -> node Bag.t +end + +(** {2 Helpers} *) + +val make : term -> t +(** Make a new equivalence class whose representative is the given term *) + +val payload_find: f:(payload -> 'a option) -> t -> 'a option + +val payload_pred: f:(payload -> bool) -> t -> bool + +val set_payload : ?can_erase:(payload -> bool) -> t -> payload -> unit +(** Add given payload + @param can_erase if provided, checks whether an existing value + is to be replaced instead of adding a new entry *) + +(**/**) +val dummy : t +val unsafe_repr_of_node : t -> repr +val unsafe_repr_seq_of_seq : t Sequence.t -> repr Sequence.t +(**/**) diff --git a/src/smt/Hash.ml b/src/smt/Hash.ml new file mode 100644 index 00000000..3d47a2b1 --- /dev/null +++ b/src/smt/Hash.ml @@ -0,0 +1,39 @@ + +(* This file is free software. See file "license" for more details. *) + +type 'a t = 'a -> int + +let bool b = if b then 1 else 2 + +let int i = i land max_int + +let string (s:string) = Hashtbl.hash s + +let combine f a b = Hashtbl.seeded_hash a (f b) + +let combine2 a b = Hashtbl.seeded_hash a b + +let combine3 a b c = + combine2 a b + |> combine2 c + +let combine4 a b c d = + combine2 a b + |> combine2 c + |> combine2 d + +let pair f g (x,y) = combine2 (f x) (g y) + +let opt f = function + | None -> 42 + | Some x -> combine2 43 (f x) + +let list f l = List.fold_left (combine f) 0x42 l +let array f = Array.fold_left (combine f) 0x43 +let iarray f = IArray.fold (combine f) 0x44 +let seq f seq = + let h = ref 0x43 in + seq (fun x -> h := combine f !h x); + !h + +let poly x = Hashtbl.hash x diff --git a/src/smt/Hash.mli b/src/smt/Hash.mli new file mode 100644 index 00000000..5e5a1dd2 --- /dev/null +++ b/src/smt/Hash.mli @@ -0,0 +1,24 @@ + +(* This file is free software. See file "license" for more details. *) + +type 'a t = 'a -> int + +val bool : bool t +val int : int t +val string : string t +val combine : 'a t -> int -> 'a -> int + +val pair : 'a t -> 'b t -> ('a * 'b) t + +val opt : 'a t -> 'a option t +val list : 'a t -> 'a list t +val array : 'a t -> 'a array t +val iarray : 'a t -> 'a IArray.t t +val seq : 'a t -> 'a Sequence.t t + +val combine2 : int -> int -> int +val combine3 : int -> int -> int -> int +val combine4 : int -> int -> int -> int -> int + +val poly : 'a t +(** the regular polymorphic hash function *) diff --git a/src/smt/IArray.ml b/src/smt/IArray.ml new file mode 100644 index 00000000..79196e56 --- /dev/null +++ b/src/smt/IArray.ml @@ -0,0 +1,155 @@ + +(* This file is free software. See file "license" for more details. *) + +type 'a t = 'a array + +let empty = [| |] + +let is_empty a = Array.length a = 0 + +let length = Array.length + +let singleton x = [| x |] + +let doubleton x y = [| x; y |] + +let make n x = Array.make n x + +let init n f = Array.init n f + +let get = Array.get + +let unsafe_get = Array.unsafe_get + +let set a n x = + let a' = Array.copy a in + a'.(n) <- x; + a' + +let map = Array.map + +let mapi = Array.mapi + +let append a b = + let na = length a in + Array.init (na + length b) + (fun i -> if i < na then a.(i) else b.(i-na)) + +let iter = Array.iter + +let iteri = Array.iteri + +let fold = Array.fold_left + +let foldi f acc a = + let n = ref 0 in + Array.fold_left + (fun acc x -> + let acc = f acc !n x in + incr n; + acc) + acc a + +exception ExitNow + +let for_all p a = + try + Array.iter (fun x -> if not (p x) then raise ExitNow) a; + true + with ExitNow -> false + +let exists p a = + try + Array.iter (fun x -> if p x then raise ExitNow) a; + false + with ExitNow -> true + +(** {2 Conversions} *) + +type 'a sequence = ('a -> unit) -> unit +type 'a gen = unit -> 'a option + +let of_list = Array.of_list + +let to_list = Array.to_list + +let of_array_unsafe a = a (* careful with that axe, Eugene *) + +let to_seq a k = iter k a + +let of_seq s = + let l = ref [] in + s (fun x -> l := x :: !l); + Array.of_list (List.rev !l) + +(*$Q + Q.(list int) (fun l -> \ + let g = Sequence.of_list l in \ + of_seq g |> to_seq |> Sequence.to_list = l) +*) + +let rec gen_to_list_ acc g = match g() with + | None -> List.rev acc + | Some x -> gen_to_list_ (x::acc) g + +let of_gen g = + let l = gen_to_list_ [] g in + Array.of_list l + +let to_gen a = + let i = ref 0 in + fun () -> + if !i < Array.length a then ( + let x = a.(!i) in + incr i; + Some x + ) else None + +(*$Q + Q.(list int) (fun l -> \ + let g = Gen.of_list l in \ + of_gen g |> to_gen |> Gen.to_list = l) +*) + +(** {2 IO} *) + +type 'a printer = Format.formatter -> 'a -> unit + +let print ?(start="[|") ?(stop="|]") ?(sep=";") pp_item out a = + Format.pp_print_string out start; + for k = 0 to Array.length a - 1 do + if k > 0 then ( + Format.pp_print_string out sep; + Format.pp_print_cut out () + ); + pp_item out a.(k) + done; + Format.pp_print_string out stop; + () + +(** {2 Binary} *) + +let equal = CCArray.equal +let compare = CCArray.compare +let for_all2 = CCArray.for_all2 +let exists2 = CCArray.exists2 + +let map2 f a b = + if length a <> length b then invalid_arg "map2"; + init (length a) (fun i -> f (unsafe_get a i) (unsafe_get b i)) + +let iter2 f a b = + if length a <> length b then invalid_arg "iter2"; + for i = 0 to length a-1 do + f (unsafe_get a i) (unsafe_get b i) + done + +let fold2 f acc a b = + if length a <> length b then invalid_arg "fold2"; + let rec aux acc i = + if i=length a then acc + else + let acc = f acc (unsafe_get a i) (unsafe_get b i) in + aux acc (i+1) + in + aux acc 0 diff --git a/src/smt/IArray.mli b/src/smt/IArray.mli new file mode 100644 index 00000000..c1e92712 --- /dev/null +++ b/src/smt/IArray.mli @@ -0,0 +1,98 @@ + +(* This file is free software. See file "license" for more details. *) + +type 'a t +(** Array of values of type 'a. The underlying type really is + an array, but it will never be modified. + + It should be covariant but OCaml will not accept it. *) + +val empty : 'a t + +val is_empty : _ t -> bool + +val length : _ t -> int + +val singleton : 'a -> 'a t + +val doubleton : 'a -> 'a -> 'a t + +val make : int -> 'a -> 'a t +(** [make n x] makes an array of [n] times [x] *) + +val init : int -> (int -> 'a) -> 'a t +(** [init n f] makes the array [[| f 0; f 1; ... ; f (n-1) |]]. + @raise Invalid_argument if [n < 0] *) + +val get : 'a t -> int -> 'a +(** Access the element *) + +val unsafe_get : 'a t -> int -> 'a +(** Unsafe access, not bound-checked. Use with caution *) + +val set : 'a t -> int -> 'a -> 'a t +(** Copy the array and modify its copy *) + +val map : ('a -> 'b) -> 'a t -> 'b t + +val mapi : (int -> 'a -> 'b) -> 'a t -> 'b t + +val append : 'a t -> 'a t -> 'a t + +val iter : ('a -> unit) -> 'a t -> unit + +val iteri : (int -> 'a -> unit) -> 'a t -> unit + +val foldi : ('a -> int -> 'b -> 'a) -> 'a -> 'b t -> 'a + +val fold : ('a -> 'b -> 'a) -> 'a -> 'b t -> 'a + +val for_all : ('a -> bool) -> 'a t -> bool + +val exists : ('a -> bool) -> 'a t -> bool + +(** {2 Conversions} *) + +type 'a sequence = ('a -> unit) -> unit +type 'a gen = unit -> 'a option + +val of_list : 'a list -> 'a t + +val to_list : 'a t -> 'a list + +val of_array_unsafe : 'a array -> 'a t +(** Take ownership of the given array. Careful, the array must {b NOT} + be modified afterwards! *) + +val to_seq : 'a t -> 'a sequence + +val of_seq : 'a sequence -> 'a t + +val of_gen : 'a gen -> 'a t + +val to_gen : 'a t -> 'a gen + +(** {2 IO} *) + +type 'a printer = Format.formatter -> 'a -> unit + +val print : + ?start:string -> ?stop:string -> ?sep:string -> + 'a printer -> 'a t printer + +(** {2 Binary} *) + +val equal : ('a -> 'a -> bool) -> 'a t -> 'a t -> bool + +val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int + +val for_all2 : ('a -> 'a -> bool) -> 'a t -> 'a t -> bool + +val exists2 : ('a -> 'a -> bool) -> 'a t -> 'a t -> bool + +val map2 : ('a -> 'b -> 'c) -> 'a t -> 'b t -> 'c t + +val fold2 : ('acc -> 'a -> 'b -> 'acc) -> 'acc -> 'a t -> 'b t -> 'acc + +val iter2 : ('a -> 'b -> unit) -> 'a t -> 'b t -> unit + diff --git a/src/smt/ID.ml b/src/smt/ID.ml new file mode 100644 index 00000000..8f928002 --- /dev/null +++ b/src/smt/ID.ml @@ -0,0 +1,40 @@ + +(* This file is free software. See file "license" for more details. *) + +type t = { + id: int; + name: string; +} + +let make = + let n = ref 0 in + fun name -> + let x = { id= !n; name; } in + incr n; + x + +let makef fmt = CCFormat.ksprintf ~f:make fmt + +let copy {name;_} = make name + +let id {id;_} = id +let to_string id = id.name + +let equal a b = a.id=b.id +let compare a b = CCInt.compare a.id b.id +let hash a = CCHash.int a.id +let pp out a = Format.fprintf out "%s/%d" a.name a.id +let pp_name out a = CCFormat.string out a.name +let to_string_full a = Printf.sprintf "%s/%d" a.name a.id + +module AsKey = struct + type t_ = t + type t = t_ + let equal = equal + let compare = compare + let hash = hash +end + +module Map = CCMap.Make(AsKey) +module Set = CCSet.Make(AsKey) +module Tbl = CCHashtbl.Make(AsKey) diff --git a/src/smt/ID.mli b/src/smt/ID.mli new file mode 100644 index 00000000..494d13fc --- /dev/null +++ b/src/smt/ID.mli @@ -0,0 +1,26 @@ + +(* This file is free software. See file "license" for more details. *) + +(** {1 Unique Identifiers} *) + +type t + +val make : string -> t +val makef : ('a, Format.formatter, unit, t) format4 -> 'a +val copy : t -> t + +val id : t -> int + +val to_string : t -> string +val to_string_full : t -> string + +include Intf.EQ with type t := t +include Intf.ORD with type t := t +include Intf.HASH with type t := t +include Intf.PRINT with type t := t + +val pp_name : t CCFormat.printer + +module Map : CCMap.S with type key = t +module Set : CCSet.S with type elt = t +module Tbl : CCHashtbl.S with type key = t diff --git a/src/smt/Intf.ml b/src/smt/Intf.ml new file mode 100644 index 00000000..d8eedace --- /dev/null +++ b/src/smt/Intf.ml @@ -0,0 +1,24 @@ + +(* This file is free software. See file "license" for more details. *) + +module type EQ = sig + type t + val equal : t -> t -> bool +end + +module type ORD = sig + type t + val compare : t -> t -> int +end + +module type HASH = sig + type t + val hash : t -> int +end + +module type PRINT = sig + type t + val pp : t CCFormat.printer +end + +type 'a printer = Format.formatter -> 'a -> unit diff --git a/src/smt/Lit.ml b/src/smt/Lit.ml new file mode 100644 index 00000000..1930112d --- /dev/null +++ b/src/smt/Lit.ml @@ -0,0 +1,54 @@ + +open CDCL +open Solver_types + +type t = lit + +let neg l = {l with lit_sign=not l.lit_sign} + +let sign t = t.lit_sign +let view (t:t): lit_view = t.lit_view + +let abs t: t = {t with lit_sign=true} + +let make ~sign v = {lit_sign=sign; lit_view=v} + +(* assume the ID is fresh *) +let fresh_with id = make ~sign:true (Lit_fresh id) + +(* fresh boolean literal *) +let fresh: unit -> t = + let n = ref 0 in + fun () -> + let id = ID.makef "#fresh_%d" !n in + incr n; + make ~sign:true (Lit_fresh id) + +let dummy = fresh() + +let atom ?(sign=true) (t:term) : t = + let t, sign' = Term.abs t in + let sign = if not sign' then not sign else sign in + make ~sign (Lit_atom t) + +let eq tst a b = atom ~sign:true (Term.eq tst a b) +let neq tst a b = atom ~sign:false (Term.eq tst a b) +let expanded t = make ~sign:true (Lit_expanded t) + +let cstor_test tst cstor t = atom ~sign:true (Term.cstor_test tst cstor t) + +let as_atom (lit:t) : (term * bool) option = match lit.lit_view with + | Lit_atom t -> Some (t, lit.lit_sign) + | _ -> None + +let hash = hash_lit +let compare = cmp_lit +let equal a b = compare a b = 0 +let pp = pp_lit +let print = pp + +let norm l = + if l.lit_sign then l, CDCL.Same_sign else neg l, CDCL.Negated + +module Set = CCSet.Make(struct type t = lit let compare=compare end) +module Tbl = CCHashtbl.Make(struct type t = lit let equal=equal let hash=hash end) diff --git a/src/smt/Lit.mli b/src/smt/Lit.mli new file mode 100644 index 00000000..ba178daf --- /dev/null +++ b/src/smt/Lit.mli @@ -0,0 +1,29 @@ +(** {2 Literals} *) + +open CDCL +open Solver_types + +type t = lit +val neg : t -> t +val abs : t -> t +val sign : t -> bool +val view : t -> lit_view +val as_atom : t -> (term * bool) option +val fresh_with : ID.t -> t +val fresh : unit -> t +val dummy : t +val atom : ?sign:bool -> term -> t +val cstor_test : data_cstor -> term -> t +val eq : Term.state -> term -> term -> t +val neq : Term.state -> term -> term -> t +val cstor_test : Term.state -> data_cstor -> term -> t +val expanded : term -> t +val hash : t -> int +val compare : t -> t -> int +val equal : t -> t -> bool +val print : t Fmt.printer +val pp : t Fmt.printer +val norm : t -> t * CDCL.negated +module Set : CCSet.S with type elt = t +module Tbl : CCHashtbl.S with type key = t + diff --git a/src/smt/Solver_types.ml b/src/smt/Solver_types.ml new file mode 100644 index 00000000..889d2149 --- /dev/null +++ b/src/smt/Solver_types.ml @@ -0,0 +1,274 @@ + +open CDCL + +module Fmt = CCFormat +module Node_bits = CCBitField.Make(struct end) + +(* for objects that are expanded on demand only *) +type 'a lazily_expanded = + | Lazy_some of 'a + | Lazy_none + +(* main term cell. *) +and term = { + mutable term_id: int; (* unique ID *) + mutable term_ty: ty; + term_cell: term term_cell; +} + +(* term shallow structure *) +and 'a term_cell = + | True + | App_cst of cst * 'a IArray.t (* full, first-order application *) + | If of 'a * 'a * 'a + | Case of 'a * 'a ID.Map.t (* check head constructor *) + | Builtin of 'a builtin + +and 'a builtin = + | B_not of 'a + | B_eq of 'a * 'a + | B_and of 'a * 'a + | B_or of 'a * 'a + | B_imply of 'a * 'a + +(** A node of the congruence closure. + An equivalence class is represented by its "root" element, + the representative. + + If there is a normal form in the congruence class, then the + representative is a normal form *) +and cc_node = { + n_term: term; + mutable n_bits: Node_bits.t; (* bitfield for various properties *) + mutable n_class: cc_node Bag.t; (* terms in the same equiv class *) + mutable n_parents: cc_node Bag.t; (* parent terms of the whole equiv class *) + mutable n_root: cc_node; (* representative of congruence class (itself if a representative) *) + mutable n_expl: (cc_node * cc_explanation) option; (* the rooted forest for explanations *) + mutable n_payload: cc_node_payload list; (* list of theory payloads *) +} + +(** Theory-extensible payloads *) +and cc_node_payload = .. + +(* atomic explanation in the congruence closure *) +and cc_explanation = + | CC_reduction (* by pure reduction, tautologically equal *) + | CC_lit of lit (* because of this literal *) + | CC_congruence of cc_node * cc_node (* same shape *) + | CC_injectivity of cc_node * cc_node (* arguments of those constructors *) + | CC_reduce_eq of cc_node * cc_node (* reduce because those are equal *) +(* TODO: theory expl *) + +(* boolean literal *) +and lit = { + lit_view: lit_view; + lit_sign: bool; +} + +and lit_view = + | Lit_fresh of ID.t (* fresh literals *) + | Lit_atom of term + | Lit_expanded of term (* expanded? used for recursive calls mostly *) + (* TODO: remove this, unfold on the fly *) + +and cst = { + cst_id: ID.t; + cst_kind: cst_kind; +} + +and cst_kind = + | Cst_undef of ty (* simple undefined constant *) + | Cst_cstor of data_cstor lazy_t + | Cst_proj of ty * data_cstor lazy_t * int (* [cstor, argument position] *) + | Cst_test of ty * data_cstor lazy_t (* test if [cstor] *) + | Cst_defined of ty * term lazy_t * cst_defined_info + +(* what kind of constant is that? *) +and cst_defined_info = + | Cst_recursive + | Cst_non_recursive + +(* this is a disjunction of sufficient conditions for the existence of + some meta (cst). Each condition is a literal *) +and cst_exist_conds = lit lazy_t list ref + +and 'a db_env = { + db_st: 'a option list; + db_size: int; +} + +(* Hashconsed type *) +and ty = { + mutable ty_id: int; + ty_cell: ty_cell; + ty_card: ty_card lazy_t; +} + +and ty_card = + | Finite + | Infinite + +and ty_def = + | Uninterpreted + | Data of datatype (* set of constructors *) + +and datatype = { + data_cstors: data_cstor ID.Map.t lazy_t; +} + +(* TODO: in cstor, add: + - for each selector, a special "magic" term for undefined, in + case the selector is ill-applied (Collapse 2) *) + +(* a constructor *) +and data_cstor = { + cstor_ty: ty; + cstor_args: ty IArray.t; (* argument types *) + cstor_proj: cst IArray.t lazy_t; (* projectors *) + cstor_test: cst lazy_t; (* tester *) + cstor_cst: cst; (* the cstor itself *) + cstor_card: ty_card; (* cardinality of the constructor('s args) *) +} + +and ty_cell = + | Prop + | Atomic of ID.t * ty_def + | Arrow of ty * ty + + +let[@inline] term_equal_ (a:term) b = a==b +let[@inline] term_hash_ a = a.term_id +let[@inline] term_cmp_ a b = CCInt.compare a.term_id b.term_id + +let cmp_lit a b = + let c = CCBool.compare a.lit_sign b.lit_sign in + if c<>0 then c + else ( + let int_of_cell_ = function + | Lit_fresh _ -> 0 + | Lit_atom _ -> 1 + | Lit_expanded _ -> 2 + in + match a.lit_view, b.lit_view with + | Lit_fresh i1, Lit_fresh i2 -> ID.compare i1 i2 + | Lit_atom t1, Lit_atom t2 -> term_cmp_ t1 t2 + | Lit_expanded t1, Lit_expanded t2 -> term_cmp_ t1 t2 + | Lit_fresh _, _ + | Lit_atom _, _ + | Lit_expanded _, _ + -> CCInt.compare (int_of_cell_ a.lit_view) (int_of_cell_ b.lit_view) + ) + +let cst_compare a b = ID.compare a.cst_id b.cst_id + +let hash_lit a = + let sign = a.lit_sign in + match a.lit_view with + | Lit_fresh i -> Hash.combine3 1 (Hash.bool sign) (ID.hash i) + | Lit_atom t -> Hash.combine3 2 (Hash.bool sign) (term_hash_ t) + | Lit_expanded t -> + Hash.combine3 3 (Hash.bool sign) (term_hash_ t) + +let cmp_cc_node a b = term_cmp_ a.n_term b.n_term + +let cmp_cc_expl a b = + let toint = function + | CC_congruence _ -> 0 | CC_lit _ -> 1 + | CC_reduction -> 2 | CC_injectivity _ -> 3 + | CC_reduce_eq _ -> 5 + in + begin match a, b with + | CC_congruence (t1,t2), CC_congruence (u1,u2) -> + CCOrd.(cmp_cc_node t1 u1 (cmp_cc_node, t2, u2)) + | CC_reduction, CC_reduction -> 0 + | CC_lit l1, CC_lit l2 -> cmp_lit l1 l2 + | CC_injectivity (t1,t2), CC_injectivity (u1,u2) -> + CCOrd.(cmp_cc_node t1 u1 (cmp_cc_node, t2, u2)) + | CC_reduce_eq (t1, u1), CC_reduce_eq (t2,u2) -> + CCOrd.(cmp_cc_node t1 t2 (cmp_cc_node, u1, u2)) + | CC_congruence _, _ | CC_lit _, _ | CC_reduction, _ + | CC_injectivity _, _ | CC_reduce_eq _, _ + -> CCInt.compare (toint a)(toint b) + end + +let pp_cst out a = ID.pp out a.cst_id +let id_of_cst a = a.cst_id + +let pp_db out (i,_) = Format.fprintf out "%%%d" i + +let ty_unfold ty : ty list * ty = + let rec aux acc ty = match ty.ty_cell with + | Arrow (a,b) -> aux (a::acc) b + | _ -> List.rev acc, ty + in + aux [] ty + +let rec pp_ty out t = match t.ty_cell with + | Prop -> Fmt.string out "prop" + | Atomic (id, _) -> ID.pp out id + | Arrow _ -> + let args, ret = ty_unfold t in + Format.fprintf out "(@[->@ %a@ %a@])" + (Util.pp_list pp_ty) args pp_ty ret + +let pp_term_top ~ids out t = + let rec pp out t = + pp_rec out t; + (* FIXME + if Config.pp_hashcons then Format.fprintf out "/%d" t.term_id; + *) + () + + and pp_rec out t = match t.term_cell with + | True -> Fmt.string out "true" + | App_cst (c, a) when IArray.is_empty a -> + pp_id out (id_of_cst c) + | App_cst (f,l) -> + Fmt.fprintf out "(@[<1>%a@ %a@])" pp_id (id_of_cst f) (Util.pp_iarray pp) l + | If (a, b, c) -> + Fmt.fprintf out "(@[if %a@ %a@ %a@])" pp a pp b pp c + | Case (t,m) -> + let pp_bind out (id,rhs) = + Fmt.fprintf out "(@[<1>case %a@ %a@])" pp_id id pp rhs + in + let print_map = + Fmt.seq ~sep:(Fmt.return "@ ") pp_bind + in + Fmt.fprintf out "(@[match %a@ (@[%a@])@])" + pp t print_map (ID.Map.to_seq m) + | Builtin (B_not t) -> Fmt.fprintf out "(@[not@ %a@])" pp t + | Builtin (B_and (a,b)) -> + Fmt.fprintf out "(@[and@ %a@ %a@])" pp a pp b + | Builtin (B_or (a,b)) -> + Fmt.fprintf out "(@[or@ %a@ %a@])" pp a pp b + | Builtin (B_imply (a,b)) -> + Fmt.fprintf out "(@[=>@ %a@ %a@])" pp a pp b + | Builtin (B_eq (a,b)) -> + Fmt.fprintf out "(@[=@ %a@ %a@])" pp a pp b + and pp_id = + if ids then ID.pp else ID.pp_name + in + pp out t + +let pp_term = pp_term_top ~ids:false + +let pp_lit out l = + let pp_lit_view out = function + | Lit_fresh i -> Format.fprintf out "#%a" ID.pp i + | Lit_atom t -> pp_term out t + | Lit_expanded t -> Format.fprintf out "(@[<1>expanded@ %a@])" pp_term t + in + if l.lit_sign then pp_lit_view out l.lit_view + else Format.fprintf out "(@[@<1>¬@ %a@])" pp_lit_view l.lit_view + +let pp_cc_node out n = pp_term out n.n_term + +let pp_cc_explanation out (e:cc_explanation) = match e with + | CC_reduction -> Fmt.string out "reduction" + | CC_lit lit -> pp_lit out lit + | CC_congruence (a,b) -> + Format.fprintf out "(@[congruence@ %a@ %a@])" pp_cc_node a pp_cc_node b + | CC_injectivity (a,b) -> + Format.fprintf out "(@[injectivity@ %a@ %a@])" pp_cc_node a pp_cc_node b + | CC_reduce_eq (t, u) -> + Format.fprintf out "(@[reduce_eq@ %a@ %a@])" pp_cc_node t pp_cc_node u diff --git a/src/smt/Term.ml b/src/smt/Term.ml new file mode 100644 index 00000000..01932453 --- /dev/null +++ b/src/smt/Term.ml @@ -0,0 +1,194 @@ + +open CDCL +open Solver_types + +type t = term + +let[@inline] id t = t.term_id +let[@inline] ty t = t.term_ty +let[@inline] cell t = t.term_cell + +let equal = term_equal_ +let hash = term_hash_ +let compare a b = CCInt.compare a.term_id b.term_id + +type state = { + tbl : term Term_cell.Tbl.t; + mutable n: int; + true_ : t lazy_t; + false_ : t lazy_t; +} + +let mk_real_ st c : t = + let term_ty = Term_cell.ty c in + let t = { + term_id= st.n; + term_ty; + term_cell=c; + } in + st.n <- 1 + st.n; + Term_cell.Tbl.add st.tbl c t; + t + +let[@inline] make st (c:t term_cell) : t = + try Term_cell.Tbl.find st.tbl c + with Not_found -> mk_real_ st c + +let[@inline] true_ st = Lazy.force st.true_ +let[@inline] false_ st = Lazy.force st.false_ + +let create ?(size=1024) () : state = + let rec st ={ + n=2; + tbl=Term_cell.Tbl.create size; + true_ = lazy (make st Term_cell.true_); + false_ = lazy (make st (Term_cell.not_ (true_ st))); + } in + ignore (Lazy.force st.true_); + ignore (Lazy.force st.false_); (* not true *) + st + +let[@inline] all_terms st = Term_cell.Tbl.values st.tbl + +let app_cst st f a = + let cell = Term_cell.app_cst f a in + make st cell + +let const st c = app_cst st c IArray.empty + +let case st u m = make st (Term_cell.case u m) + +let if_ st a b c = make st (Term_cell.if_ a b c) + +let not_ st t = make st (Term_cell.not_ t) + +let and_ st a b = make st (Term_cell.and_ a b) +let or_ st a b = make st (Term_cell.or_ a b) +let imply st a b = make st (Term_cell.imply a b) +let eq st a b = make st (Term_cell.eq a b) +let neq st a b = not_ st (eq st a b) +let builtin st b = make st (Term_cell.builtin b) + +(* "eager" and, evaluating [a] first *) +let and_eager st a b = if_ st a b (false_ st) + +let cstor_test st cstor t = make st (Term_cell.cstor_test cstor t) +let cstor_proj st cstor i t = make st (Term_cell.cstor_proj cstor i t) + +(* might need to tranfer the negation from [t] to [sign] *) +let abs t : t * bool = match t.term_cell with + | Builtin (B_not t) -> t, false + | _ -> t, true + +let rec and_l st = function + | [] -> true_ st + | [t] -> t + | a :: l -> and_ st a (and_l st l) + +let or_l st = function + | [] -> false_ st + | [t] -> t + | a :: l -> List.fold_left (or_ st) a l + +let fold_map_builtin + (f:'a -> term -> 'a * term) (acc:'a) (b:t builtin): 'a * t builtin = + let fold_binary acc a b = + let acc, a = f acc a in + let acc, b = f acc b in + acc, a, b + in + match b with + | B_not t -> + let acc, t' = f acc t in + acc, B_not t' + | B_and (a,b) -> + let acc, a, b = fold_binary acc a b in + acc, B_and (a,b) + | B_or (a,b) -> + let acc, a, b = fold_binary acc a b in + acc, B_or (a, b) + | B_eq (a,b) -> + let acc, a, b = fold_binary acc a b in + acc, B_eq (a, b) + | B_imply (a,b) -> + let acc, a, b = fold_binary acc a b in + acc, B_imply (a, b) + +let is_const t = match t.term_cell with + | App_cst (_, a) -> IArray.is_empty a + | _ -> false + +let map_builtin f b = + let (), b = fold_map_builtin (fun () t -> (), f t) () b in + b + +let builtin_to_seq b yield = match b with + | B_not t -> yield t + | B_or (a,b) + | B_imply (a,b) + | B_eq (a,b) -> yield a; yield b + | B_and (a,b) -> yield a; yield b + +module As_key = struct + type t = term + let compare = compare + let equal = equal + let hash = hash +end + +module Map = CCMap.Make(As_key) +module Tbl = CCHashtbl.Make(As_key) + +let to_seq t yield = + let rec aux t = + yield t; + match t.term_cell with + | True -> () + | App_cst (_,a) -> IArray.iter aux a + | If (a,b,c) -> aux a; aux b; aux c + | Case (t, m) -> + aux t; + ID.Map.iter (fun _ rhs -> aux rhs) m + | Builtin b -> builtin_to_seq b aux + in + aux t + +(* return [Some] iff the term is an undefined constant *) +let as_cst_undef (t:term): (cst * Ty.t) option = + match t.term_cell with + | App_cst (c, a) when IArray.is_empty a -> + Cst.as_undefined c + | _ -> None + +(* return [Some (cstor,ty,args)] if the term is a constructor + applied to some arguments *) +let as_cstor_app (t:term): (cst * data_cstor * term IArray.t) option = + match t.term_cell with + | App_cst ({cst_kind=Cst_cstor (lazy cstor); _} as c, a) -> + Some (c,cstor,a) + | _ -> None + +(* typical view for unification/equality *) +type unif_form = + | Unif_cst of cst * Ty.t + | Unif_cstor of cst * data_cstor * term IArray.t + | Unif_none + +let as_unif (t:term): unif_form = match t.term_cell with + | App_cst ({cst_kind=Cst_undef ty; _} as c, a) when IArray.is_empty a -> + Unif_cst (c,ty) + | App_cst ({cst_kind=Cst_cstor (lazy cstor); _} as c, a) -> + Unif_cstor (c,cstor,a) + | _ -> Unif_none + +let fpf = Format.fprintf + +let pp = Solver_types.pp_term + + + +let dummy : t = { + term_id= -1; + term_ty=Ty.prop; + term_cell=True; +} diff --git a/src/smt/Term.mli b/src/smt/Term.mli new file mode 100644 index 00000000..481d435b --- /dev/null +++ b/src/smt/Term.mli @@ -0,0 +1,74 @@ + +open CDCL +open Solver_types + +type t = term + +val id : t -> int +val cell : t -> term term_cell +val ty : t -> Ty.t +val equal : t -> t -> bool +val compare : t -> t -> int +val hash : t -> int + +type state + +val create : ?size:int -> unit -> state + +val true_ : state -> t +val false_ : state -> t +val const : state -> cst -> t +val app_cst : state -> cst -> t IArray.t -> t +val if_: state -> t -> t -> t -> t +val case : state -> t -> t ID.Map.t -> t +val builtin : state -> t builtin -> t +val and_ : state -> t -> t -> t +val or_ : state -> t -> t -> t +val not_ : state -> t -> t +val imply : state -> t -> t -> t +val eq : state -> t -> t -> t +val neq : state -> t -> t -> t +val and_eager : state -> t -> t -> t (* evaluate left argument first *) + +val cstor_test : state -> data_cstor -> term -> t +val cstor_proj : state -> data_cstor -> int -> term -> t + +val and_l : state -> t list -> t +val or_l : state -> t list -> t + +val abs : t -> t * bool + +val map_builtin : (t -> t) -> t builtin -> t builtin +val builtin_to_seq : t builtin -> t Sequence.t + +val to_seq : t -> t Sequence.t + +val all_terms : state -> t Sequence.t + +val pp : t Fmt.printer + +(** {6 Views} *) + +val is_const : t -> bool + +(* return [Some] iff the term is an undefined constant *) +val as_cst_undef : t -> (cst * Ty.t) option + +val as_cstor_app : t -> (cst * data_cstor * t IArray.t) option + +(* typical view for unification/equality *) +type unif_form = + | Unif_cst of cst * Ty.t + | Unif_cstor of cst * data_cstor * term IArray.t + | Unif_none + +val as_unif : t -> unif_form + +(** {6 Containers} *) + +module Tbl : CCHashtbl.S with type key = t +module Map : CCMap.S with type key = t + +(**/**) +val dummy : t +(**/**) diff --git a/src/smt/Term_cell.ml b/src/smt/Term_cell.ml new file mode 100644 index 00000000..d07fadde --- /dev/null +++ b/src/smt/Term_cell.ml @@ -0,0 +1,141 @@ + +open CDCL +open Solver_types + +(* TODO: normalization of {!term_cell} for use in signatures? *) + +type t = term term_cell + +module type ARG = sig + type t + val hash : t -> int + val equal : t -> t -> bool +end + +module Make_eq(A : ARG) = struct + let sub_hash = A.hash + let sub_eq = A.equal + + let hash (t:A.t term_cell) : int = match t with + | True -> 1 + | App_cst (f,l) -> + Hash.combine3 4 (Cst.hash f) (Hash.iarray sub_hash l) + | If (a,b,c) -> Hash.combine4 7 (sub_hash a) (sub_hash b) (sub_hash c) + | Case (u,m) -> + let hash_m = + Hash.seq (Hash.pair ID.hash sub_hash) (ID.Map.to_seq m) + in + Hash.combine3 8 (sub_hash u) hash_m + | Builtin (B_not a) -> Hash.combine2 20 (sub_hash a) + | Builtin (B_and (t1,t2)) -> Hash.combine3 21 (sub_hash t1) (sub_hash t2) + | Builtin (B_or (t1,t2)) -> Hash.combine3 22 (sub_hash t1) (sub_hash t2) + | Builtin (B_imply (t1,t2)) -> Hash.combine3 23 (sub_hash t1) (sub_hash t2) + | Builtin (B_eq (t1,t2)) -> Hash.combine3 24 (sub_hash t1) (sub_hash t2) + + (* equality that relies on physical equality of subterms *) + let equal (a:A.t term_cell) b : bool = match a, b with + | True, True -> true + | App_cst (f1, a1), App_cst (f2, a2) -> + Cst.equal f1 f2 && IArray.equal sub_eq a1 a2 + | If (a1,b1,c1), If (a2,b2,c2) -> + sub_eq a1 a2 && sub_eq b1 b2 && sub_eq c1 c2 + | Case (u1, m1), Case (u2, m2) -> + sub_eq u1 u2 && + ID.Map.for_all + (fun k1 rhs1 -> + try sub_eq rhs1 (ID.Map.find k1 m2) + with Not_found -> false) + m1 + && + ID.Map.for_all (fun k2 _ -> ID.Map.mem k2 m1) m2 + | Builtin b1, Builtin b2 -> + begin match b1, b2 with + | B_not a1, B_not a2 -> sub_eq a1 a2 + | B_and (a1,b1), B_and (a2,b2) + | B_or (a1,b1), B_or (a2,b2) + | B_eq (a1,b1), B_eq (a2,b2) + | B_imply (a1,b1), B_imply (a2,b2) -> sub_eq a1 a2 && sub_eq b1 b2 + | B_not _, _ | B_and _, _ | B_eq _, _ + | B_or _, _ | B_imply _, _ -> false + end + | True, _ + | App_cst _, _ + | If _, _ + | Case _, _ + | Builtin _, _ + -> false +end[@@inline] + +include Make_eq(struct + type t = term + let equal (t1:t) t2 = t1==t2 + let hash (t:term): int = t.term_id + end) + +let true_ = True + +let app_cst f a = App_cst (f, a) +let const c = App_cst (c, IArray.empty) + +let case u m = Case (u,m) +let if_ a b c = + assert (Ty.equal b.term_ty c.term_ty); + If (a,b,c) + +let cstor_test cstor t = + app_cst (Lazy.force cstor.cstor_test) (IArray.singleton t) + +let cstor_proj cstor i t = + let p = IArray.get (Lazy.force cstor.cstor_proj) i in + app_cst p (IArray.singleton t) + +let builtin b = + (* normalize a bit *) + let b = match b with + | B_eq (a,b) when a.term_id > b.term_id -> B_eq (b,a) + | B_and (a,b) when a.term_id > b.term_id -> B_and (b,a) + | B_or (a,b) when a.term_id > b.term_id -> B_or (b,a) + | _ -> b + in + Builtin b + +let not_ t = match t.term_cell with + | Builtin (B_not t') -> t'.term_cell + | _ -> builtin (B_not t) + +let and_ a b = builtin (B_and (a,b)) +let or_ a b = builtin (B_or (a,b)) +let imply a b = builtin (B_imply (a,b)) +let eq a b = builtin (B_eq (a,b)) + +(* type of an application *) +let rec app_ty_ ty l : Ty.t = match Ty.view ty, l with + | _, [] -> ty + | Arrow (ty_a,ty_rest), a::tail -> + assert (Ty.equal ty_a a.term_ty); + app_ty_ ty_rest tail + | (Prop | Atomic _), _::_ -> + assert false + +let ty (t:t): Ty.t = match t with + | True -> Ty.prop + | App_cst (f, a) -> + let n_args, ret = Cst.ty f |> Ty.unfold_n in + if n_args = IArray.length a + then ret (* fully applied *) + else ( + assert (IArray.length a < n_args); + app_ty_ (Cst.ty f) (IArray.to_list a) + ) + | If (_,b,_) -> b.term_ty + | Case (_,m) -> + let _, rhs = ID.Map.choose m in + rhs.term_ty + | Builtin _ -> Ty.prop + +module Tbl = CCHashtbl.Make(struct + type t = term term_cell + let equal = equal + let hash = hash + end) + diff --git a/src/smt/Term_cell.mli b/src/smt/Term_cell.mli new file mode 100644 index 00000000..bd2726ef --- /dev/null +++ b/src/smt/Term_cell.mli @@ -0,0 +1,38 @@ + +open CDCL +open Solver_types + +type t = term term_cell + +val equal : t -> t -> bool +val hash : t -> int + +val true_ : t +val const : cst -> t +val app_cst : cst -> term IArray.t -> t +val cstor_test : data_cstor -> term -> t +val cstor_proj : data_cstor -> int -> term -> t +val case : term -> term ID.Map.t -> t +val if_ : term -> term -> term -> t +val builtin : term builtin -> t +val and_ : term -> term -> t +val or_ : term -> term -> t +val not_ : term -> t +val imply : term -> term -> t +val eq : term -> term -> t + +val ty : t -> Ty.t +(** Compute the type of this term cell. Not totally free *) + +module Tbl : CCHashtbl.S with type key = t + +module type ARG = sig + type t + val hash : t -> int + val equal : t -> t -> bool +end + +module Make_eq(X : ARG) : sig + val equal : X.t term_cell -> X.t term_cell -> bool + val hash : X.t term_cell -> int +end diff --git a/src/smt/Ty.ml b/src/smt/Ty.ml new file mode 100644 index 00000000..319231f0 --- /dev/null +++ b/src/smt/Ty.ml @@ -0,0 +1,88 @@ + +open CDCL +open Solver_types + +type t = ty +type cell = ty_cell +type def = ty_def + +let view t = t.ty_cell + +let equal a b = a.ty_id = b.ty_id +let compare a b = CCInt.compare a.ty_id b.ty_id +let hash a = a.ty_id + +module Tbl_cell = CCHashtbl.Make(struct + type t = ty_cell + let equal a b = match a, b with + | Prop, Prop -> true + | Atomic (i1,_), Atomic (i2,_) -> ID.equal i1 i2 + | Arrow (a1,b1), Arrow (a2,b2) -> + equal a1 a2 && equal b1 b2 + | Prop, _ + | Atomic _, _ + | Arrow _, _ -> false + + let hash t = match t with + | Prop -> 1 + | Atomic (i,_) -> Hash.combine2 2 (ID.hash i) + | Arrow (a,b) -> Hash.combine3 3 (hash a) (hash b) + end) + +(* build a type *) +let make_ : ty_cell -> card:ty_card lazy_t -> t = + let tbl : t Tbl_cell.t = Tbl_cell.create 128 in + let n = ref 0 in + fun c ~card -> + try Tbl_cell.find tbl c + with Not_found -> + let ty_id = !n in + incr n; + let ty = {ty_id; ty_cell=c; ty_card=card; } in + Tbl_cell.add tbl c ty; + ty + +let prop = make_ Prop ~card:(Lazy.from_val Finite) + +let atomic id def ~card = make_ (Atomic (id,def)) ~card + +let arrow a b = + let card = lazy (Ty_card.(Lazy.force b.ty_card ^ Lazy.force a.ty_card)) in + make_ (Arrow (a,b)) ~card + +let arrow_l = List.fold_right arrow + +let is_prop t = + match t.ty_cell with | Prop -> true | _ -> false + +let is_data t = + match t.ty_cell with | Atomic (_, Data _) -> true | _ -> false + +let is_uninterpreted t = + match t.ty_cell with | Atomic (_, Uninterpreted) -> true | _ -> false + +let is_arrow t = + match t.ty_cell with | Arrow _ -> true | _ -> false + +let unfold = ty_unfold + +let unfold_n ty : int * t = + let rec aux n ty = match ty.ty_cell with + | Arrow (_,b) -> aux (n+1) b + | _ -> n, ty + in + aux 0 ty + +let pp = pp_ty + +(* representation as a single identifier *) +let rec mangle t : string = match t.ty_cell with + | Prop -> "prop" + | Atomic (id,_) -> ID.to_string id + | Arrow (a,b) -> mangle a ^ "_" ^ mangle b + +module Tbl = CCHashtbl.Make(struct + type t = ty + let equal = equal + let hash = hash + end) diff --git a/src/smt/Ty.mli b/src/smt/Ty.mli new file mode 100644 index 00000000..ae7f45a1 --- /dev/null +++ b/src/smt/Ty.mli @@ -0,0 +1,32 @@ + +(** {1 Hashconsed Types} *) + +open CDCL + +type t = Solver_types.ty +type cell = Solver_types.ty_cell +type def = Solver_types.ty_def + +val view : t -> cell + +val prop : t +val atomic : ID.t -> def -> card:Ty_card.t lazy_t -> t +val arrow : t -> t -> t +val arrow_l : t list -> t -> t + +val is_prop : t -> bool +val is_data : t -> bool +val is_uninterpreted : t -> bool +val is_arrow : t -> bool +val unfold : t -> t list * t +val unfold_n : t -> int * t + +val mangle : t -> string + +include Intf.EQ with type t := t +include Intf.ORD with type t := t +include Intf.HASH with type t := t +include Intf.PRINT with type t := t + +module Tbl : CCHashtbl.S with type key = t + diff --git a/src/smt/Ty_card.ml b/src/smt/Ty_card.ml new file mode 100644 index 00000000..a6e91975 --- /dev/null +++ b/src/smt/Ty_card.ml @@ -0,0 +1,19 @@ + +open Solver_types + +type t = ty_card + +let (+) a b = match a, b with Finite, Finite -> Finite | _ -> Infinite +let ( * ) a b = match a, b with Finite, Finite -> Finite | _ -> Infinite +let ( ^ ) a b = match a, b with Finite, Finite -> Finite | _ -> Infinite +let finite = Finite +let infinite = Infinite + +let sum = List.fold_left (+) Finite +let product = List.fold_left ( * ) Finite + +let equal = (=) +let compare = Pervasives.compare +let pp out = function + | Finite -> Fmt.string out "finite" + | Infinite -> Fmt.string out "infinite" diff --git a/src/smt/Ty_card.mli b/src/smt/Ty_card.mli new file mode 100644 index 00000000..07569e71 --- /dev/null +++ b/src/smt/Ty_card.mli @@ -0,0 +1,19 @@ + +(** {1 Type Cardinality} *) + +open CDCL + +type t = Solver_types.ty_card + +val (+) : t -> t -> t +val ( * ) : t -> t -> t +val ( ^ ) : t -> t -> t +val finite : t +val infinite : t + +val sum : t list -> t +val product : t list -> t + +val equal : t -> t -> bool +val compare : t -> t -> int +val pp : t Intf.printer diff --git a/src/smt/Util.ml b/src/smt/Util.ml new file mode 100644 index 00000000..576bebc8 --- /dev/null +++ b/src/smt/Util.ml @@ -0,0 +1,28 @@ + +(* This file is free software. See file "license" for more details. *) + +(** {1 Util} *) + +module Fmt = CCFormat + +type 'a printer = 'a CCFormat.printer + +let pp_sep sep out () = Format.fprintf out "%s@," sep + +let pp_list ?(sep=" ") pp out l = + Fmt.list ~sep:(pp_sep sep) pp out l + +let pp_array ?(sep=" ") pp out l = + Fmt.array ~sep:(pp_sep sep) pp out l + +let pp_iarray ?(sep=" ") pp out a = + Fmt.seq ~sep:(pp_sep sep) pp out (IArray.to_seq a) + +exception Error of string + +let () = Printexc.register_printer + (function + | Error msg -> Some ("internal error: " ^ msg) + | _ -> None) + +let errorf msg = Fmt.ksprintf msg ~f:(fun s -> raise (Error s)) diff --git a/src/smt/Util.mli b/src/smt/Util.mli new file mode 100644 index 00000000..c6f77edf --- /dev/null +++ b/src/smt/Util.mli @@ -0,0 +1,17 @@ + +(* This file is free software. See file "license" for more details. *) + +(** {1 Utils} *) + +type 'a printer = 'a CCFormat.printer + +val pp_list : ?sep:string -> 'a printer -> 'a list printer + +val pp_array : ?sep:string -> 'a printer -> 'a array printer + +val pp_iarray : ?sep:string -> 'a CCFormat.printer -> 'a IArray.t CCFormat.printer + +exception Error of string + +val errorf : ('a, Format.formatter, unit, 'b) format4 -> 'a +(** @raise Error when called *) diff --git a/src/smt/jbuild b/src/smt/jbuild new file mode 100644 index 00000000..137ba092 --- /dev/null +++ b/src/smt/jbuild @@ -0,0 +1,9 @@ +; vim:ft=lisp: + +(library + ((name CDCL_smt) + (public_name cdcl.smt) + (libraries (containers containers.data sequence cdcl)) + (flags (:standard -w +a-4-44-58-60@8 -color always -safe-string -short-paths)) + (ocamlopt_flags (:standard -O3 -color always + -unbox-closures -unbox-closures-factor 20))))