diff --git a/src/th-distinct/Sidekick_th_distinct.ml b/src/th-distinct/Sidekick_th_distinct.ml new file mode 100644 index 00000000..91858e47 --- /dev/null +++ b/src/th-distinct/Sidekick_th_distinct.ml @@ -0,0 +1,206 @@ + +module Term = Sidekick_smt.Term +module Theory = Sidekick_smt.Theory + +module type ARG = sig + module T : sig + type t + type state + val pp : t Fmt.printer + val equal : t -> t -> bool + val hash : t -> int + val as_distinct : t -> t Sequence.t option + val mk_eq : state -> t -> t -> t + end + module Lit : sig + type t + val term : t -> T.t + val neg : t -> t + val sign : t -> bool + val compare : t -> t -> int + val atom : T.t -> t + val pp : t Fmt.printer + end +end + +module type S = sig + type term + type term_state + type lit + + type data + val key : (term, lit, data) Sidekick_cc.Key.t + val th : Sidekick_smt.Theory.t +end + +module Make(A : ARG with type Lit.t = Sidekick_smt.Lit.t + and type T.t = Sidekick_smt.Term.t + and type T.state = Sidekick_smt.Term.state) = struct + module T = A.T + module Lit = A.Lit + module IM = CCMap.Make(Lit) + + type term = T.t + type lit = A.Lit.t + type data = term IM.t (* "distinct" lit -> term appearing under it*) + + let key : (term,lit,data) Sidekick_cc.Key.t = + let merge m1 m2 = + IM.merge_safe m1 m2 + ~f:(fun _ pair -> match pair with + | `Left x | `Right x -> Some x + | `Both (x,_) -> Some x) + and eq = IM.equal T.equal + and pp out m = + Fmt.fprintf out + "{@[%a@]}" Fmt.(seq ~sep:(return ",@ ") @@ pair Lit.pp T.pp) (IM.to_seq m) + in + Sidekick_cc.Key.create + ~pp + ~name:"distinct" + ~merge ~eq () + + (* micro theory *) + module Micro(CC : Sidekick_cc.Congruence_closure.S + with type term = T.t + and type lit = Lit.t + and module Key = Sidekick_cc.Key) = struct + exception E_exit + + let on_merge cc n1 m1 n2 m2 expl12 = + try + let _i = + IM.merge + (fun lit o1 o2 -> + match o1, o2 with + | Some t1, Some t2 -> + (* conflict! two terms under the same "distinct" [lit] + are merged, where [lit = distinct(t1,t2,…)]. + The conflict is: + [lit, t1=n1, t2=n2, expl-merge(n1,n2) ==> false] + *) + assert (not @@ T.equal t1 t2); + let expl = CC.Expl.mk_list + [expl12; + CC.Expl.mk_lit lit; + CC.Expl.mk_merge n1 (CC.Theory.add_term cc t1); + CC.Expl.mk_merge n2 (CC.Theory.add_term cc t2); + ] in + CC.Theory.raise_conflict cc expl; + raise_notrace E_exit + | _ -> None) + m1 m2 + in + () + with E_exit -> () + + let on_new_term _ _ = None + + let m_th = + CC.Theory.make ~key ~on_merge ~on_new_term () + end + + module T_tbl = CCHashtbl.Make(T) + type st = { + tst: T.state; + expanded: unit T_tbl.t; (* negative "distinct" that have been case-split on *) + } + + let create tst : st = { expanded=T_tbl.create 12; tst; } + + let pp_c out c = Fmt.fprintf out "(@[%a@])" (Util.pp_list Lit.pp) c + + module CC = Sidekick_smt.CC + + let process_lit (st:st) (acts:Theory.actions) (lit:Lit.t) (lit_t:term) (subs:term Sequence.t) : unit = + let (module A) = acts in + Log.debugf 5 (fun k->k "(@[th_distinct.process@ %a@])" Lit.pp lit); + let add_axiom c = A.add_persistent_axiom c in + let cc = A.cc in + if Lit.sign lit then ( + (* assert [distinct subs], so we update the node of each [t in subs] + with [lit] *) + (* FIXME: detect if some subs are already equal *) + subs + (fun sub -> + let n = CC.Theory.add_term cc sub in + CC.Theory.add_data cc n key (IM.singleton lit sub)); + ) else if not @@ T_tbl.mem st.expanded lit_t then ( + (* add clause [distinct t1…tn ∨ ∨_{i,j>i} t_i=j] *) + T_tbl.add st.expanded lit_t (); + let l = Sequence.to_list subs in + let c = + Sequence.diagonal_l l + |> Sequence.map (fun (t,u) -> Lit.atom @@ T.mk_eq st.tst t u) + |> Sequence.to_rev_list + in + let c = Lit.neg lit :: c in + Log.debugf 5 (fun k->k "(@[tseitin.distinct.case-split@ %a@])" pp_c c); + add_axiom c + ) + + let partial_check st (acts:Theory.actions) lits : unit = + lits + (fun lit -> + let t = Lit.term lit in + match T.as_distinct t with + | None -> () + | Some subs -> process_lit st acts lit t subs) + + let th = + Sidekick_smt.Theory.make + ~name:"distinct" + ~partial_check + ~final_check:(fun _ _ _ -> ()) + ~create () +end + +module T = struct + open Sidekick_smt + open Sidekick_smt.Solver_types + module T = Term + + type t = Term.t + type terms = t IArray.t + let compare = Term.compare + let to_seq = IArray.to_seq + + let id_distinct = ID.make "distinct" + + let relevant _id _ _ = true + let get_ty _ _ = Ty.prop + let abs ~self _a = self, true + + let as_distinct t : _ option = + match T.view t with + | T.App_cst ({cst_id;_}, args) when ID.equal cst_id id_distinct -> + Some (IArray.to_seq args) + | _ -> None + + let eval args = + let module Value = Sidekick_smt.Value in + if + Sequence.diagonal (IArray.to_seq args) + |> Sequence.for_all (fun (x,y) -> not @@ Value.equal x y) + then Value.true_ + else Value.false_ + + let c_distinct = + {cst_id=id_distinct; + cst_view=Cst_def { + pp=None; abs; ty=get_ty; relevant; do_cc=true; eval; }; } + + let distinct st a = + if IArray.length a <= 1 + then T.true_ st + else T.app_cst st c_distinct a + + let distinct_l st = function + | [] | [_] -> T.true_ st + | xs -> distinct st (IArray.of_list xs) +end + +let distinct = T.distinct +let distinct_l = T.distinct_l + +include Make(T) diff --git a/src/th-distinct/Sidekick_th_distinct.mli b/src/th-distinct/Sidekick_th_distinct.mli new file mode 100644 index 00000000..4ce5b68a --- /dev/null +++ b/src/th-distinct/Sidekick_th_distinct.mli @@ -0,0 +1,53 @@ + +(** {1 Theory of "distinct"} + + This is an extension of the congruence closure that handles + "distinct" efficiently. + *) + +module Term = Sidekick_smt.Term + +module type ARG = sig + module T : sig + type t + type state + val pp : t Fmt.printer + val equal : t -> t -> bool + val hash : t -> int + val as_distinct : t -> t Sequence.t option + val mk_eq : state -> t -> t -> t + end + module Lit : sig + type t + val term : t -> T.t + val neg : t -> t + val sign : t -> bool + val compare : t -> t -> int + val atom : T.t -> t + val pp : t Fmt.printer + end +end + +module type S = sig + type term + type term_state + type lit + + type data + val key : (term, lit, data) Sidekick_cc.Key.t + val th : Sidekick_smt.Theory.t +end + +(* TODO: generalize theories *) +module Make(A : ARG with type T.t = Sidekick_smt.Term.t + and type T.state = Sidekick_smt.Term.state + and type Lit.t = Sidekick_smt.Lit.t) : + S with type term = A.T.t + and type lit = A.Lit.t + and type term_state = A.T.state + +val distinct : Term.state -> Term.t IArray.t -> Term.t +val distinct_l : Term.state -> Term.t list -> Term.t + +(** Default instance *) +include S with type term = Term.t and type lit = Sidekick_smt.Lit.t diff --git a/src/th-distinct/dune b/src/th-distinct/dune new file mode 100644 index 00000000..57fcb854 --- /dev/null +++ b/src/th-distinct/dune @@ -0,0 +1,7 @@ + +(library + (name Sidekick_th_distinct) + (public_name sidekick.smt.th-distinct) + (libraries containers sidekick.smt) + (flags :standard -open Sidekick_util)) +