From a4e3fd5a69d7cf7a6840e9254d826ee847665ddd Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sun, 1 Dec 2019 19:26:12 -0600 Subject: [PATCH] feat: provide simple `repr->monoid` mapping in core --- src/cc/Sidekick_cc.ml | 7 ++ src/core/Sidekick_core.ml | 100 +++++++++++++++++++++++++++- src/th-cstor/Sidekick_th_cstor.ml | 104 ++++++++++++++---------------- src/th-data/Sidekick_th_data.ml | 96 ++++++++++++++------------- src/th-data/Sidekick_th_data.mli | 57 ++++++++++++++++ 5 files changed, 262 insertions(+), 102 deletions(-) create mode 100644 src/th-data/Sidekick_th_data.mli diff --git a/src/cc/Sidekick_cc.ml b/src/cc/Sidekick_cc.ml index f6ad319a..5b234a02 100644 --- a/src/cc/Sidekick_cc.ml +++ b/src/cc/Sidekick_cc.ml @@ -289,6 +289,13 @@ module Make (A: CC_ARG) let[@inline] on_backtrack cc f : unit = Backtrack_stack.push_if_nonzero_level cc.undo f + let set_bitfield cc field b n = + let old = N.get_field field n in + if old <> b then ( + on_backtrack cc (fun () -> N.set_field field old n); + N.set_field field b n; + ) + (* check if [t] is in the congruence closure. Invariant: [in_cc t ∧ do_cc t => forall u subterm t, in_cc u] *) let[@inline] mem (cc:t) (t:term): bool = T_tbl.mem cc.tbl t diff --git a/src/core/Sidekick_core.ml b/src/core/Sidekick_core.ml index 6be110ba..1af55db0 100644 --- a/src/core/Sidekick_core.ml +++ b/src/core/Sidekick_core.ml @@ -182,7 +182,6 @@ module type CC_S = sig and are merged automatically when classes are merged. *) val get_field : bitfield -> t -> bool - val set_field : bitfield -> bool -> t -> unit end module Expl : sig @@ -236,6 +235,10 @@ module type CC_S = sig (** Allocate a new bitfield for the nodes. See {!N.bitfield}. *) + val set_bitfield : t -> N.bitfield -> bool -> N.t -> unit + (** Set the bitfield for the node. This will be backtracked. + See {!N.bitfield}. *) + (* TODO: remove? this is managed by the solver anyway? *) val on_pre_merge : t -> ev_on_pre_merge -> unit (** Add a function to be called when two classes are merged *) @@ -658,3 +661,98 @@ module type SOLVER = sig val pp_stats : t CCFormat.printer end + +(** Helper for keeping track of state for each class *) + +module type MONOID_ARG = sig + module SI : SOLVER_INTERNAL + type t + val pp : t Fmt.printer + val name : string (* name of the monoid's value (short) *) + val of_term : SI.CC.N.t -> SI.T.Term.t -> t option + val merge : SI.CC.t -> SI.CC.N.t -> t -> SI.CC.N.t -> t -> (t, SI.CC.Expl.t) result +end + +module Monoid_of_repr(M : MONOID_ARG) : sig + type t + val create_and_setup : ?size:int -> M.SI.t -> t + val push_level : t -> unit + val pop_levels : t -> int -> unit + val mem : t -> M.SI.CC.N.t -> bool + val get : t -> M.SI.CC.N.t -> M.t option +end = struct + module SI = M.SI + module T = SI.T.Term + module N = SI.CC.N + module CC = SI.CC + module N_tbl = Backtrackable_tbl.Make(N) + module Expl = SI.CC.Expl + + type t = { + values: M.t N_tbl.t; (* repr -> value for the class *) + field_has_value: N.bitfield; (* bit in CC to filter out quickly classes without value *) + } + + let push_level self = N_tbl.push_level self.values + let pop_levels self n = N_tbl.pop_levels self.values n + + let mem self n = + let res = N.get_field self.field_has_value n in + assert (if res then N_tbl.mem self.values n else true); + res + + let get self n = N_tbl.get self.values n + + let on_new_term self cc n (t:T.t) = + match M.of_term n t with + | Some v -> + Log.debugf 20 + (fun k->k "(@[monoid[%s].on-new-term@ :n %a@ :value %a@])" + M.name N.pp n M.pp v); + SI.CC.set_bitfield cc self.field_has_value true n; + N_tbl.add self.values n v + | None -> () + + (* find cell for [n] *) + let get_cell (self:t) (n:N.t) : M.t option = + N_tbl.get self.values n + (* TODO + if N.get_field self.field_has_value n then ( + try Some (N_tbl.find self.values n) + with Not_found -> + Error.errorf "repr %a has value-field bit for %s set, but is not in table" + N.pp n M.name + ) else ( + None + ) + *) + + let on_pre_merge (self:t) cc acts n1 n2 e_n1_n2 : unit = + begin match get_cell self n1, get_cell self n2 with + | Some v1, Some v2 -> + Log.debugf 5 + (fun k->k + "(@[monoid[%s].on_pre_merge@ @[:n1 %a@ :val %a@]@ @[:n2 %a@ :val %a@]@])" + M.name N.pp n1 M.pp v1 N.pp n2 M.pp v2); + begin match M.merge cc n1 v1 n2 v2 with + | Ok v' -> + N_tbl.add self.values n1 v'; + | Error expl -> + (* add [n1=n2] to the conflict *) + let expl = Expl.mk_list [ e_n1_n2; expl; ] in + SI.CC.raise_conflict_from_expl cc acts expl + end + | None, Some cr -> + SI.CC.set_bitfield cc self.field_has_value true n1; + N_tbl.add self.values n1 cr + | Some _, None -> () (* already there on the left *) + | None, None -> () + end + + let create_and_setup ?size (solver:SI.t) : t = + let field_has_value = SI.CC.allocate_bitfield (SI.cc solver) in + let self = { values=N_tbl.create ?size (); field_has_value; } in + SI.on_cc_new_term solver (on_new_term self); + SI.on_cc_pre_merge solver (on_pre_merge self); + self +end diff --git a/src/th-cstor/Sidekick_th_cstor.ml b/src/th-cstor/Sidekick_th_cstor.ml index 1d5b666b..688e5cbd 100644 --- a/src/th-cstor/Sidekick_th_cstor.ml +++ b/src/th-cstor/Sidekick_th_cstor.ml @@ -24,71 +24,61 @@ module Make(A : ARG) : S with module A = A = struct module Fun = A.S.T.Fun module Expl = SI.CC.Expl - type cstor_repr = { - t: T.t; - n: N.t; - cstor: Fun.t; - args: T.t IArray.t; - } - (* associate to each class a unique constructor term in the class (if any) *) + module Monoid = struct + module SI = SI - module N_tbl = Backtrackable_tbl.Make(N) + (* associate to each class a unique constructor term in the class (if any) *) + type t = { + t: T.t; + n: N.t; + cstor: Fun.t; + args: T.t IArray.t; + } - type t = { - cstors: cstor_repr N_tbl.t; (* repr -> cstor for the class *) - (* TODO: also allocate a bit in CC to filter out quickly classes without cstors? *) - } + let name = name + let pp out (v:t) = + Fmt.fprintf out "(@[cstor %a@ :term %a@])" Fun.pp v.cstor T.pp v.t - let push_level self = N_tbl.push_level self.cstors - let pop_levels self n = N_tbl.pop_levels self.cstors n + (* attach data to constructor terms *) + let of_term n (t:T.t) : _ option = + match A.view_as_cstor t with + | T_cstor (cstor,args) -> Some {n; t; cstor; args} + | _ -> None - (* attach data to constructor terms *) - let on_new_term self _solver n (t:T.t) = - match A.view_as_cstor t with - | T_cstor (cstor,args) -> - Log.debugf 20 - (fun k->k "(@[th-cstor.on-new-term@ %a@ :cstor %a@ @[:args@ (@[%a@])@]@]@])" - T.pp t Fun.pp cstor (Util.pp_iarray T.pp) args); - N_tbl.add self.cstors n {n; t; cstor; args}; - | _ -> () - - let on_pre_merge (self:t) cc acts n1 n2 e_n1_n2 : unit = - begin match N_tbl.get self.cstors n1, N_tbl.get self.cstors n2 with - | Some cr1, Some cr2 -> - Log.debugf 5 - (fun k->k "(@[th-cstor.on_pre_merge@ @[:c1 %a@ (term %a)@]@ @[:c2 %a@ (term %a)@]@])" - N.pp n1 T.pp cr1.t N.pp n2 T.pp cr2.t); - (* build full explanation of why the constructor terms are equal *) - let expl = - Expl.mk_list [ - e_n1_n2; - Expl.mk_merge n1 cr1.n; - Expl.mk_merge n2 cr2.n; - ] - in - if Fun.equal cr1.cstor cr2.cstor then ( - (* same function: injectivity *) - assert (IArray.length cr1.args = IArray.length cr2.args); - IArray.iter2 - (fun u1 u2 -> SI.CC.merge_t cc u1 u2 expl) - cr1.args cr2.args - ) else ( - (* different function: disjointness *) - SI.CC.raise_conflict_from_expl cc acts expl + let merge cc n1 v1 n2 v2 : _ result = + Log.debugf 5 + (fun k->k "(@[%s.merge@ @[:c1 %a (t %a)@]@ @[:c2 %a (t %a)@]@])" + name N.pp n1 T.pp v1.t N.pp n2 T.pp v2.t); + (* build full explanation of why the constructor terms are equal *) + let expl = + Expl.mk_list [ + Expl.mk_merge n1 v1.n; + Expl.mk_merge n2 v2.n; + ] + in + if Fun.equal v1.cstor v2.cstor then ( + (* same function: injectivity *) + assert (IArray.length v1.args = IArray.length v2.args); + IArray.iter2 + (fun u1 u2 -> SI.CC.merge_t cc u1 u2 expl) + v1.args v2.args; + Ok v1 + ) else ( + (* different function: disjointness *) + Error expl ) - | None, Some cr -> - N_tbl.add self.cstors n1 cr - | Some _, None -> () (* already there on the left *) - | None, None -> () - end + end + + module ST = Sidekick_core.Monoid_of_repr(Monoid) + + type t = ST.t + + let push_level = ST.push_level + let pop_levels = ST.pop_levels let create_and_setup (solver:SI.t) : t = - let self = { - cstors=N_tbl.create ~size:32 (); - } in Log.debug 1 "(setup :th-cstor)"; - SI.on_cc_new_term solver (on_new_term self); - SI.on_cc_pre_merge solver (on_pre_merge self); + let self = ST.create_and_setup ~size:32 solver in self let theory = diff --git a/src/th-data/Sidekick_th_data.ml b/src/th-data/Sidekick_th_data.ml index 2e080b32..c3c2ce06 100644 --- a/src/th-data/Sidekick_th_data.ml +++ b/src/th-data/Sidekick_th_data.ml @@ -138,19 +138,57 @@ module Make(A : ARG) : S with module A = A = struct module Card = Compute_card(A) - type cstor_repr = { - t: T.t; - n: N.t; - cstor: A.Cstor.t; - args: T.t IArray.t; - } - (* associate to each class a unique constructor term in the class (if any) *) + module Monoid_cstor = struct + module SI = SI + (* associate to each class a unique constructor term in the class (if any) *) + type t = { + t: T.t; + n: N.t; + cstor: A.Cstor.t; + args: T.t IArray.t; + } + + let name = "th-data.cstor" + let pp out (v:t) = + Fmt.fprintf out "(@[cstor %a@ :term %a@])" A.Cstor.pp v.cstor T.pp v.t + + (* attach data to constructor terms *) + let of_term n (t:T.t) : _ option = + match A.view_as_data t with + | T_cstor (cstor,args) -> Some {n; t; cstor; args} + | _ -> None + + let merge cc n1 v1 n2 v2 : _ result = + Log.debugf 5 + (fun k->k "(@[%s.merge@ @[:c1 %a (t %a)@]@ @[:c2 %a (t %a)@]@])" + name N.pp n1 T.pp v1.t N.pp n2 T.pp v2.t); + (* build full explanation of why the constructor terms are equal *) + let expl = + Expl.mk_list [ + Expl.mk_merge n1 v1.n; + Expl.mk_merge n2 v2.n; + ] + in + if A.Cstor.equal v1.cstor v2.cstor then ( + (* same function: injectivity *) + assert (IArray.length v1.args = IArray.length v2.args); + IArray.iter2 + (fun u1 u2 -> SI.CC.merge_t cc u1 u2 expl) + v1.args v2.args; + Ok v1 + ) else ( + (* different function: disjointness *) + Error expl + ) + end + + module ST_cstors = Sidekick_core.Monoid_of_repr(Monoid_cstor) module N_tbl = Backtrackable_tbl.Make(N) type t = { tst: T.state; - cstors: cstor_repr N_tbl.t; (* repr -> cstor for the class *) + cstors: ST_cstors.t; (* repr -> cstor for the class *) cards: Card.t; (* remember finiteness *) to_decide: unit N_tbl.t; (* set of terms to decide. *) case_split_done: unit T.Tbl.t; (* set of terms for which case split is done *) @@ -159,18 +197,19 @@ module Make(A : ARG) : S with module A = A = struct } let push_level self = - N_tbl.push_level self.cstors; + ST_cstors.push_level self.cstors; N_tbl.push_level self.to_decide; () let pop_levels self n = - N_tbl.pop_levels self.cstors n; + ST_cstors.pop_levels self.cstors n; N_tbl.pop_levels self.to_decide n; () (* TODO: select/is-a *) (* TODO: acyclicity *) + (* TODO: remove (* attach data to constructor terms *) let on_new_term_look_at_shape self n (t:T.t) = match A.view_as_data t with @@ -193,6 +232,7 @@ module Make(A : ARG) : S with module A = A = struct () (* N_tbl.add self.cstors n {n; t; cstor; args}; *) | T_other _ -> () + *) (* remember terms of a datatype *) let on_new_term_look_at_ty (self:t) n (t:T.t) : unit = @@ -211,40 +251,9 @@ module Make(A : ARG) : S with module A = A = struct | _ -> () let on_new_term self _solver n t = - on_new_term_look_at_shape self n t; on_new_term_look_at_ty self n t; () - let on_pre_merge (self:t) cc acts n1 n2 e_n1_n2 : unit = - begin match N_tbl.get self.cstors n1, N_tbl.get self.cstors n2 with - | Some cr1, Some cr2 -> - Log.debugf 5 - (fun k->k "(@[th-cstor.on_pre_merge@ @[:c1 %a@ (term %a)@]@ @[:c2 %a@ (term %a)@]@])" - N.pp n1 T.pp cr1.t N.pp n2 T.pp cr2.t); - (* build full explanation of why the constructor terms are equal *) - let expl = - Expl.mk_list [ - e_n1_n2; - Expl.mk_merge n1 cr1.n; - Expl.mk_merge n2 cr2.n; - ] - in - if A.Cstor.equal cr1.cstor cr2.cstor then ( - (* same function: injectivity *) - assert (IArray.length cr1.args = IArray.length cr2.args); - IArray.iter2 - (fun u1 u2 -> SI.CC.merge_t cc u1 u2 expl) - cr1.args cr2.args - ) else ( - (* different function: disjointness *) - SI.CC.raise_conflict_from_expl cc acts expl - ) - | None, Some cr -> - N_tbl.add self.cstors n1 cr - | Some _, None -> () (* already there on the left *) - | None, None -> () - end - let cstors_of_ty (ty:Ty.t) : A.Cstor.t Iter.t = match A.as_datatype ty with | Ty_data {cstors} -> cstors @@ -258,7 +267,7 @@ module Make(A : ARG) : S with module A = A = struct |> Iter.map (fun (n,_) -> SI.cc_find solver n) |> Iter.filter (fun n -> - not (N_tbl.mem self.cstors n) && + not (ST_cstors.mem self.cstors n) && not (T.Tbl.mem self.case_split_done (N.term n))) |> Iter.to_rev_list in @@ -297,14 +306,13 @@ module Make(A : ARG) : S with module A = A = struct let create_and_setup (solver:SI.t) : t = let self = { tst=SI.tst solver; - cstors=N_tbl.create ~size:32 (); + cstors=ST_cstors.create_and_setup ~size:32 solver; to_decide=N_tbl.create ~size:16 (); case_split_done=T.Tbl.create 16; cards=Card.create(); } in Log.debugf 1 (fun k->k "(setup :%s)" name); SI.on_cc_new_term solver (on_new_term self); - SI.on_cc_pre_merge solver (on_pre_merge self); SI.on_final_check solver (on_final_check self); self diff --git a/src/th-data/Sidekick_th_data.mli b/src/th-data/Sidekick_th_data.mli new file mode 100644 index 00000000..dd731307 --- /dev/null +++ b/src/th-data/Sidekick_th_data.mli @@ -0,0 +1,57 @@ + +(** Datatype-oriented view of terms. + ['c] is the representation of constructors + ['t] is the representation of terms +*) +type ('c,'t) data_view = + | T_cstor of 'c * 't IArray.t + | T_select of 'c * int * 't + | T_is_a of 'c * 't + | T_other of 't + +(** View of types in a way that is directly useful for the theory of datatypes *) +type ('c, 'ty) data_ty_view = + | Ty_arrow of 'ty Iter.t * 'ty + | Ty_app of { + args: 'ty Iter.t; + } + | Ty_data of { + cstors: 'c; + } + | Ty_other + +module type ARG = sig + module S : Sidekick_core.SOLVER + + module Cstor : sig + type t + val ty_args : t -> S.T.Ty.t Iter.t + val pp : t Fmt.printer + val equal : t -> t -> bool + end + + val as_datatype : S.T.Ty.t -> (Cstor.t Iter.t, S.T.Ty.t) data_ty_view + (** Try to view type as a datatype (with its constructors) *) + + val view_as_data : S.T.Term.t -> (Cstor.t, S.T.Term.t) data_view + (** Try to view term as a datatype term *) + + val mk_cstor : S.T.Term.state -> Cstor.t -> S.T.Term.t IArray.t -> S.T.Term.t + (** Make a constructor application term *) + + val mk_is_a: S.T.Term.state -> Cstor.t -> S.T.Term.t -> S.T.Term.t + (** Make a [is-a] term *) + + val ty_is_finite : S.T.Ty.t -> bool + (** Is the given type known to be finite? *) + + val ty_set_is_finite : S.T.Ty.t -> bool -> unit + (** Modify the "finite" field (see {!ty_is_finite}) *) +end + +module type S = sig + module A : ARG + val theory : A.S.theory +end + +module Make(A : ARG) : S with module A = A