mirror of
https://github.com/c-cube/sidekick.git
synced 2026-01-22 17:36:41 -05:00
feat: provide simple repr->monoid mapping in core
This commit is contained in:
parent
7c951c08ff
commit
a4e3fd5a69
5 changed files with 262 additions and 102 deletions
|
|
@ -289,6 +289,13 @@ module Make (A: CC_ARG)
|
|||
let[@inline] on_backtrack cc f : unit =
|
||||
Backtrack_stack.push_if_nonzero_level cc.undo f
|
||||
|
||||
let set_bitfield cc field b n =
|
||||
let old = N.get_field field n in
|
||||
if old <> b then (
|
||||
on_backtrack cc (fun () -> N.set_field field old n);
|
||||
N.set_field field b n;
|
||||
)
|
||||
|
||||
(* check if [t] is in the congruence closure.
|
||||
Invariant: [in_cc t ∧ do_cc t => forall u subterm t, in_cc u] *)
|
||||
let[@inline] mem (cc:t) (t:term): bool = T_tbl.mem cc.tbl t
|
||||
|
|
|
|||
|
|
@ -182,7 +182,6 @@ module type CC_S = sig
|
|||
and are merged automatically when classes are merged. *)
|
||||
|
||||
val get_field : bitfield -> t -> bool
|
||||
val set_field : bitfield -> bool -> t -> unit
|
||||
end
|
||||
|
||||
module Expl : sig
|
||||
|
|
@ -236,6 +235,10 @@ module type CC_S = sig
|
|||
(** Allocate a new bitfield for the nodes.
|
||||
See {!N.bitfield}. *)
|
||||
|
||||
val set_bitfield : t -> N.bitfield -> bool -> N.t -> unit
|
||||
(** Set the bitfield for the node. This will be backtracked.
|
||||
See {!N.bitfield}. *)
|
||||
|
||||
(* TODO: remove? this is managed by the solver anyway? *)
|
||||
val on_pre_merge : t -> ev_on_pre_merge -> unit
|
||||
(** Add a function to be called when two classes are merged *)
|
||||
|
|
@ -658,3 +661,98 @@ module type SOLVER = sig
|
|||
|
||||
val pp_stats : t CCFormat.printer
|
||||
end
|
||||
|
||||
(** Helper for keeping track of state for each class *)
|
||||
|
||||
module type MONOID_ARG = sig
|
||||
module SI : SOLVER_INTERNAL
|
||||
type t
|
||||
val pp : t Fmt.printer
|
||||
val name : string (* name of the monoid's value (short) *)
|
||||
val of_term : SI.CC.N.t -> SI.T.Term.t -> t option
|
||||
val merge : SI.CC.t -> SI.CC.N.t -> t -> SI.CC.N.t -> t -> (t, SI.CC.Expl.t) result
|
||||
end
|
||||
|
||||
module Monoid_of_repr(M : MONOID_ARG) : sig
|
||||
type t
|
||||
val create_and_setup : ?size:int -> M.SI.t -> t
|
||||
val push_level : t -> unit
|
||||
val pop_levels : t -> int -> unit
|
||||
val mem : t -> M.SI.CC.N.t -> bool
|
||||
val get : t -> M.SI.CC.N.t -> M.t option
|
||||
end = struct
|
||||
module SI = M.SI
|
||||
module T = SI.T.Term
|
||||
module N = SI.CC.N
|
||||
module CC = SI.CC
|
||||
module N_tbl = Backtrackable_tbl.Make(N)
|
||||
module Expl = SI.CC.Expl
|
||||
|
||||
type t = {
|
||||
values: M.t N_tbl.t; (* repr -> value for the class *)
|
||||
field_has_value: N.bitfield; (* bit in CC to filter out quickly classes without value *)
|
||||
}
|
||||
|
||||
let push_level self = N_tbl.push_level self.values
|
||||
let pop_levels self n = N_tbl.pop_levels self.values n
|
||||
|
||||
let mem self n =
|
||||
let res = N.get_field self.field_has_value n in
|
||||
assert (if res then N_tbl.mem self.values n else true);
|
||||
res
|
||||
|
||||
let get self n = N_tbl.get self.values n
|
||||
|
||||
let on_new_term self cc n (t:T.t) =
|
||||
match M.of_term n t with
|
||||
| Some v ->
|
||||
Log.debugf 20
|
||||
(fun k->k "(@[monoid[%s].on-new-term@ :n %a@ :value %a@])"
|
||||
M.name N.pp n M.pp v);
|
||||
SI.CC.set_bitfield cc self.field_has_value true n;
|
||||
N_tbl.add self.values n v
|
||||
| None -> ()
|
||||
|
||||
(* find cell for [n] *)
|
||||
let get_cell (self:t) (n:N.t) : M.t option =
|
||||
N_tbl.get self.values n
|
||||
(* TODO
|
||||
if N.get_field self.field_has_value n then (
|
||||
try Some (N_tbl.find self.values n)
|
||||
with Not_found ->
|
||||
Error.errorf "repr %a has value-field bit for %s set, but is not in table"
|
||||
N.pp n M.name
|
||||
) else (
|
||||
None
|
||||
)
|
||||
*)
|
||||
|
||||
let on_pre_merge (self:t) cc acts n1 n2 e_n1_n2 : unit =
|
||||
begin match get_cell self n1, get_cell self n2 with
|
||||
| Some v1, Some v2 ->
|
||||
Log.debugf 5
|
||||
(fun k->k
|
||||
"(@[monoid[%s].on_pre_merge@ @[:n1 %a@ :val %a@]@ @[:n2 %a@ :val %a@]@])"
|
||||
M.name N.pp n1 M.pp v1 N.pp n2 M.pp v2);
|
||||
begin match M.merge cc n1 v1 n2 v2 with
|
||||
| Ok v' ->
|
||||
N_tbl.add self.values n1 v';
|
||||
| Error expl ->
|
||||
(* add [n1=n2] to the conflict *)
|
||||
let expl = Expl.mk_list [ e_n1_n2; expl; ] in
|
||||
SI.CC.raise_conflict_from_expl cc acts expl
|
||||
end
|
||||
| None, Some cr ->
|
||||
SI.CC.set_bitfield cc self.field_has_value true n1;
|
||||
N_tbl.add self.values n1 cr
|
||||
| Some _, None -> () (* already there on the left *)
|
||||
| None, None -> ()
|
||||
end
|
||||
|
||||
let create_and_setup ?size (solver:SI.t) : t =
|
||||
let field_has_value = SI.CC.allocate_bitfield (SI.cc solver) in
|
||||
let self = { values=N_tbl.create ?size (); field_has_value; } in
|
||||
SI.on_cc_new_term solver (on_new_term self);
|
||||
SI.on_cc_pre_merge solver (on_pre_merge self);
|
||||
self
|
||||
end
|
||||
|
|
|
|||
|
|
@ -24,71 +24,61 @@ module Make(A : ARG) : S with module A = A = struct
|
|||
module Fun = A.S.T.Fun
|
||||
module Expl = SI.CC.Expl
|
||||
|
||||
type cstor_repr = {
|
||||
t: T.t;
|
||||
n: N.t;
|
||||
cstor: Fun.t;
|
||||
args: T.t IArray.t;
|
||||
}
|
||||
(* associate to each class a unique constructor term in the class (if any) *)
|
||||
module Monoid = struct
|
||||
module SI = SI
|
||||
|
||||
module N_tbl = Backtrackable_tbl.Make(N)
|
||||
(* associate to each class a unique constructor term in the class (if any) *)
|
||||
type t = {
|
||||
t: T.t;
|
||||
n: N.t;
|
||||
cstor: Fun.t;
|
||||
args: T.t IArray.t;
|
||||
}
|
||||
|
||||
type t = {
|
||||
cstors: cstor_repr N_tbl.t; (* repr -> cstor for the class *)
|
||||
(* TODO: also allocate a bit in CC to filter out quickly classes without cstors? *)
|
||||
}
|
||||
let name = name
|
||||
let pp out (v:t) =
|
||||
Fmt.fprintf out "(@[cstor %a@ :term %a@])" Fun.pp v.cstor T.pp v.t
|
||||
|
||||
let push_level self = N_tbl.push_level self.cstors
|
||||
let pop_levels self n = N_tbl.pop_levels self.cstors n
|
||||
(* attach data to constructor terms *)
|
||||
let of_term n (t:T.t) : _ option =
|
||||
match A.view_as_cstor t with
|
||||
| T_cstor (cstor,args) -> Some {n; t; cstor; args}
|
||||
| _ -> None
|
||||
|
||||
(* attach data to constructor terms *)
|
||||
let on_new_term self _solver n (t:T.t) =
|
||||
match A.view_as_cstor t with
|
||||
| T_cstor (cstor,args) ->
|
||||
Log.debugf 20
|
||||
(fun k->k "(@[th-cstor.on-new-term@ %a@ :cstor %a@ @[:args@ (@[%a@])@]@]@])"
|
||||
T.pp t Fun.pp cstor (Util.pp_iarray T.pp) args);
|
||||
N_tbl.add self.cstors n {n; t; cstor; args};
|
||||
| _ -> ()
|
||||
|
||||
let on_pre_merge (self:t) cc acts n1 n2 e_n1_n2 : unit =
|
||||
begin match N_tbl.get self.cstors n1, N_tbl.get self.cstors n2 with
|
||||
| Some cr1, Some cr2 ->
|
||||
Log.debugf 5
|
||||
(fun k->k "(@[th-cstor.on_pre_merge@ @[:c1 %a@ (term %a)@]@ @[:c2 %a@ (term %a)@]@])"
|
||||
N.pp n1 T.pp cr1.t N.pp n2 T.pp cr2.t);
|
||||
(* build full explanation of why the constructor terms are equal *)
|
||||
let expl =
|
||||
Expl.mk_list [
|
||||
e_n1_n2;
|
||||
Expl.mk_merge n1 cr1.n;
|
||||
Expl.mk_merge n2 cr2.n;
|
||||
]
|
||||
in
|
||||
if Fun.equal cr1.cstor cr2.cstor then (
|
||||
(* same function: injectivity *)
|
||||
assert (IArray.length cr1.args = IArray.length cr2.args);
|
||||
IArray.iter2
|
||||
(fun u1 u2 -> SI.CC.merge_t cc u1 u2 expl)
|
||||
cr1.args cr2.args
|
||||
) else (
|
||||
(* different function: disjointness *)
|
||||
SI.CC.raise_conflict_from_expl cc acts expl
|
||||
let merge cc n1 v1 n2 v2 : _ result =
|
||||
Log.debugf 5
|
||||
(fun k->k "(@[%s.merge@ @[:c1 %a (t %a)@]@ @[:c2 %a (t %a)@]@])"
|
||||
name N.pp n1 T.pp v1.t N.pp n2 T.pp v2.t);
|
||||
(* build full explanation of why the constructor terms are equal *)
|
||||
let expl =
|
||||
Expl.mk_list [
|
||||
Expl.mk_merge n1 v1.n;
|
||||
Expl.mk_merge n2 v2.n;
|
||||
]
|
||||
in
|
||||
if Fun.equal v1.cstor v2.cstor then (
|
||||
(* same function: injectivity *)
|
||||
assert (IArray.length v1.args = IArray.length v2.args);
|
||||
IArray.iter2
|
||||
(fun u1 u2 -> SI.CC.merge_t cc u1 u2 expl)
|
||||
v1.args v2.args;
|
||||
Ok v1
|
||||
) else (
|
||||
(* different function: disjointness *)
|
||||
Error expl
|
||||
)
|
||||
| None, Some cr ->
|
||||
N_tbl.add self.cstors n1 cr
|
||||
| Some _, None -> () (* already there on the left *)
|
||||
| None, None -> ()
|
||||
end
|
||||
end
|
||||
|
||||
module ST = Sidekick_core.Monoid_of_repr(Monoid)
|
||||
|
||||
type t = ST.t
|
||||
|
||||
let push_level = ST.push_level
|
||||
let pop_levels = ST.pop_levels
|
||||
|
||||
let create_and_setup (solver:SI.t) : t =
|
||||
let self = {
|
||||
cstors=N_tbl.create ~size:32 ();
|
||||
} in
|
||||
Log.debug 1 "(setup :th-cstor)";
|
||||
SI.on_cc_new_term solver (on_new_term self);
|
||||
SI.on_cc_pre_merge solver (on_pre_merge self);
|
||||
let self = ST.create_and_setup ~size:32 solver in
|
||||
self
|
||||
|
||||
let theory =
|
||||
|
|
|
|||
|
|
@ -138,19 +138,57 @@ module Make(A : ARG) : S with module A = A = struct
|
|||
|
||||
module Card = Compute_card(A)
|
||||
|
||||
type cstor_repr = {
|
||||
t: T.t;
|
||||
n: N.t;
|
||||
cstor: A.Cstor.t;
|
||||
args: T.t IArray.t;
|
||||
}
|
||||
(* associate to each class a unique constructor term in the class (if any) *)
|
||||
module Monoid_cstor = struct
|
||||
module SI = SI
|
||||
|
||||
(* associate to each class a unique constructor term in the class (if any) *)
|
||||
type t = {
|
||||
t: T.t;
|
||||
n: N.t;
|
||||
cstor: A.Cstor.t;
|
||||
args: T.t IArray.t;
|
||||
}
|
||||
|
||||
let name = "th-data.cstor"
|
||||
let pp out (v:t) =
|
||||
Fmt.fprintf out "(@[cstor %a@ :term %a@])" A.Cstor.pp v.cstor T.pp v.t
|
||||
|
||||
(* attach data to constructor terms *)
|
||||
let of_term n (t:T.t) : _ option =
|
||||
match A.view_as_data t with
|
||||
| T_cstor (cstor,args) -> Some {n; t; cstor; args}
|
||||
| _ -> None
|
||||
|
||||
let merge cc n1 v1 n2 v2 : _ result =
|
||||
Log.debugf 5
|
||||
(fun k->k "(@[%s.merge@ @[:c1 %a (t %a)@]@ @[:c2 %a (t %a)@]@])"
|
||||
name N.pp n1 T.pp v1.t N.pp n2 T.pp v2.t);
|
||||
(* build full explanation of why the constructor terms are equal *)
|
||||
let expl =
|
||||
Expl.mk_list [
|
||||
Expl.mk_merge n1 v1.n;
|
||||
Expl.mk_merge n2 v2.n;
|
||||
]
|
||||
in
|
||||
if A.Cstor.equal v1.cstor v2.cstor then (
|
||||
(* same function: injectivity *)
|
||||
assert (IArray.length v1.args = IArray.length v2.args);
|
||||
IArray.iter2
|
||||
(fun u1 u2 -> SI.CC.merge_t cc u1 u2 expl)
|
||||
v1.args v2.args;
|
||||
Ok v1
|
||||
) else (
|
||||
(* different function: disjointness *)
|
||||
Error expl
|
||||
)
|
||||
end
|
||||
|
||||
module ST_cstors = Sidekick_core.Monoid_of_repr(Monoid_cstor)
|
||||
module N_tbl = Backtrackable_tbl.Make(N)
|
||||
|
||||
type t = {
|
||||
tst: T.state;
|
||||
cstors: cstor_repr N_tbl.t; (* repr -> cstor for the class *)
|
||||
cstors: ST_cstors.t; (* repr -> cstor for the class *)
|
||||
cards: Card.t; (* remember finiteness *)
|
||||
to_decide: unit N_tbl.t; (* set of terms to decide. *)
|
||||
case_split_done: unit T.Tbl.t; (* set of terms for which case split is done *)
|
||||
|
|
@ -159,18 +197,19 @@ module Make(A : ARG) : S with module A = A = struct
|
|||
}
|
||||
|
||||
let push_level self =
|
||||
N_tbl.push_level self.cstors;
|
||||
ST_cstors.push_level self.cstors;
|
||||
N_tbl.push_level self.to_decide;
|
||||
()
|
||||
|
||||
let pop_levels self n =
|
||||
N_tbl.pop_levels self.cstors n;
|
||||
ST_cstors.pop_levels self.cstors n;
|
||||
N_tbl.pop_levels self.to_decide n;
|
||||
()
|
||||
|
||||
(* TODO: select/is-a *)
|
||||
(* TODO: acyclicity *)
|
||||
|
||||
(* TODO: remove
|
||||
(* attach data to constructor terms *)
|
||||
let on_new_term_look_at_shape self n (t:T.t) =
|
||||
match A.view_as_data t with
|
||||
|
|
@ -193,6 +232,7 @@ module Make(A : ARG) : S with module A = A = struct
|
|||
()
|
||||
(* N_tbl.add self.cstors n {n; t; cstor; args}; *)
|
||||
| T_other _ -> ()
|
||||
*)
|
||||
|
||||
(* remember terms of a datatype *)
|
||||
let on_new_term_look_at_ty (self:t) n (t:T.t) : unit =
|
||||
|
|
@ -211,40 +251,9 @@ module Make(A : ARG) : S with module A = A = struct
|
|||
| _ -> ()
|
||||
|
||||
let on_new_term self _solver n t =
|
||||
on_new_term_look_at_shape self n t;
|
||||
on_new_term_look_at_ty self n t;
|
||||
()
|
||||
|
||||
let on_pre_merge (self:t) cc acts n1 n2 e_n1_n2 : unit =
|
||||
begin match N_tbl.get self.cstors n1, N_tbl.get self.cstors n2 with
|
||||
| Some cr1, Some cr2 ->
|
||||
Log.debugf 5
|
||||
(fun k->k "(@[th-cstor.on_pre_merge@ @[:c1 %a@ (term %a)@]@ @[:c2 %a@ (term %a)@]@])"
|
||||
N.pp n1 T.pp cr1.t N.pp n2 T.pp cr2.t);
|
||||
(* build full explanation of why the constructor terms are equal *)
|
||||
let expl =
|
||||
Expl.mk_list [
|
||||
e_n1_n2;
|
||||
Expl.mk_merge n1 cr1.n;
|
||||
Expl.mk_merge n2 cr2.n;
|
||||
]
|
||||
in
|
||||
if A.Cstor.equal cr1.cstor cr2.cstor then (
|
||||
(* same function: injectivity *)
|
||||
assert (IArray.length cr1.args = IArray.length cr2.args);
|
||||
IArray.iter2
|
||||
(fun u1 u2 -> SI.CC.merge_t cc u1 u2 expl)
|
||||
cr1.args cr2.args
|
||||
) else (
|
||||
(* different function: disjointness *)
|
||||
SI.CC.raise_conflict_from_expl cc acts expl
|
||||
)
|
||||
| None, Some cr ->
|
||||
N_tbl.add self.cstors n1 cr
|
||||
| Some _, None -> () (* already there on the left *)
|
||||
| None, None -> ()
|
||||
end
|
||||
|
||||
let cstors_of_ty (ty:Ty.t) : A.Cstor.t Iter.t =
|
||||
match A.as_datatype ty with
|
||||
| Ty_data {cstors} -> cstors
|
||||
|
|
@ -258,7 +267,7 @@ module Make(A : ARG) : S with module A = A = struct
|
|||
|> Iter.map (fun (n,_) -> SI.cc_find solver n)
|
||||
|> Iter.filter
|
||||
(fun n ->
|
||||
not (N_tbl.mem self.cstors n) &&
|
||||
not (ST_cstors.mem self.cstors n) &&
|
||||
not (T.Tbl.mem self.case_split_done (N.term n)))
|
||||
|> Iter.to_rev_list
|
||||
in
|
||||
|
|
@ -297,14 +306,13 @@ module Make(A : ARG) : S with module A = A = struct
|
|||
let create_and_setup (solver:SI.t) : t =
|
||||
let self = {
|
||||
tst=SI.tst solver;
|
||||
cstors=N_tbl.create ~size:32 ();
|
||||
cstors=ST_cstors.create_and_setup ~size:32 solver;
|
||||
to_decide=N_tbl.create ~size:16 ();
|
||||
case_split_done=T.Tbl.create 16;
|
||||
cards=Card.create();
|
||||
} in
|
||||
Log.debugf 1 (fun k->k "(setup :%s)" name);
|
||||
SI.on_cc_new_term solver (on_new_term self);
|
||||
SI.on_cc_pre_merge solver (on_pre_merge self);
|
||||
SI.on_final_check solver (on_final_check self);
|
||||
self
|
||||
|
||||
|
|
|
|||
57
src/th-data/Sidekick_th_data.mli
Normal file
57
src/th-data/Sidekick_th_data.mli
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
|
||||
(** Datatype-oriented view of terms.
|
||||
['c] is the representation of constructors
|
||||
['t] is the representation of terms
|
||||
*)
|
||||
type ('c,'t) data_view =
|
||||
| T_cstor of 'c * 't IArray.t
|
||||
| T_select of 'c * int * 't
|
||||
| T_is_a of 'c * 't
|
||||
| T_other of 't
|
||||
|
||||
(** View of types in a way that is directly useful for the theory of datatypes *)
|
||||
type ('c, 'ty) data_ty_view =
|
||||
| Ty_arrow of 'ty Iter.t * 'ty
|
||||
| Ty_app of {
|
||||
args: 'ty Iter.t;
|
||||
}
|
||||
| Ty_data of {
|
||||
cstors: 'c;
|
||||
}
|
||||
| Ty_other
|
||||
|
||||
module type ARG = sig
|
||||
module S : Sidekick_core.SOLVER
|
||||
|
||||
module Cstor : sig
|
||||
type t
|
||||
val ty_args : t -> S.T.Ty.t Iter.t
|
||||
val pp : t Fmt.printer
|
||||
val equal : t -> t -> bool
|
||||
end
|
||||
|
||||
val as_datatype : S.T.Ty.t -> (Cstor.t Iter.t, S.T.Ty.t) data_ty_view
|
||||
(** Try to view type as a datatype (with its constructors) *)
|
||||
|
||||
val view_as_data : S.T.Term.t -> (Cstor.t, S.T.Term.t) data_view
|
||||
(** Try to view term as a datatype term *)
|
||||
|
||||
val mk_cstor : S.T.Term.state -> Cstor.t -> S.T.Term.t IArray.t -> S.T.Term.t
|
||||
(** Make a constructor application term *)
|
||||
|
||||
val mk_is_a: S.T.Term.state -> Cstor.t -> S.T.Term.t -> S.T.Term.t
|
||||
(** Make a [is-a] term *)
|
||||
|
||||
val ty_is_finite : S.T.Ty.t -> bool
|
||||
(** Is the given type known to be finite? *)
|
||||
|
||||
val ty_set_is_finite : S.T.Ty.t -> bool -> unit
|
||||
(** Modify the "finite" field (see {!ty_is_finite}) *)
|
||||
end
|
||||
|
||||
module type S = sig
|
||||
module A : ARG
|
||||
val theory : A.S.theory
|
||||
end
|
||||
|
||||
module Make(A : ARG) : S with module A = A
|
||||
Loading…
Add table
Reference in a new issue