feat: provide simple repr->monoid mapping in core

This commit is contained in:
Simon Cruanes 2019-12-01 19:26:12 -06:00
parent 7c951c08ff
commit a4e3fd5a69
5 changed files with 262 additions and 102 deletions

View file

@ -289,6 +289,13 @@ module Make (A: CC_ARG)
let[@inline] on_backtrack cc f : unit =
Backtrack_stack.push_if_nonzero_level cc.undo f
let set_bitfield cc field b n =
let old = N.get_field field n in
if old <> b then (
on_backtrack cc (fun () -> N.set_field field old n);
N.set_field field b n;
)
(* check if [t] is in the congruence closure.
Invariant: [in_cc t do_cc t => forall u subterm t, in_cc u] *)
let[@inline] mem (cc:t) (t:term): bool = T_tbl.mem cc.tbl t

View file

@ -182,7 +182,6 @@ module type CC_S = sig
and are merged automatically when classes are merged. *)
val get_field : bitfield -> t -> bool
val set_field : bitfield -> bool -> t -> unit
end
module Expl : sig
@ -236,6 +235,10 @@ module type CC_S = sig
(** Allocate a new bitfield for the nodes.
See {!N.bitfield}. *)
val set_bitfield : t -> N.bitfield -> bool -> N.t -> unit
(** Set the bitfield for the node. This will be backtracked.
See {!N.bitfield}. *)
(* TODO: remove? this is managed by the solver anyway? *)
val on_pre_merge : t -> ev_on_pre_merge -> unit
(** Add a function to be called when two classes are merged *)
@ -658,3 +661,98 @@ module type SOLVER = sig
val pp_stats : t CCFormat.printer
end
(** Helper for keeping track of state for each class *)
module type MONOID_ARG = sig
module SI : SOLVER_INTERNAL
type t
val pp : t Fmt.printer
val name : string (* name of the monoid's value (short) *)
val of_term : SI.CC.N.t -> SI.T.Term.t -> t option
val merge : SI.CC.t -> SI.CC.N.t -> t -> SI.CC.N.t -> t -> (t, SI.CC.Expl.t) result
end
module Monoid_of_repr(M : MONOID_ARG) : sig
type t
val create_and_setup : ?size:int -> M.SI.t -> t
val push_level : t -> unit
val pop_levels : t -> int -> unit
val mem : t -> M.SI.CC.N.t -> bool
val get : t -> M.SI.CC.N.t -> M.t option
end = struct
module SI = M.SI
module T = SI.T.Term
module N = SI.CC.N
module CC = SI.CC
module N_tbl = Backtrackable_tbl.Make(N)
module Expl = SI.CC.Expl
type t = {
values: M.t N_tbl.t; (* repr -> value for the class *)
field_has_value: N.bitfield; (* bit in CC to filter out quickly classes without value *)
}
let push_level self = N_tbl.push_level self.values
let pop_levels self n = N_tbl.pop_levels self.values n
let mem self n =
let res = N.get_field self.field_has_value n in
assert (if res then N_tbl.mem self.values n else true);
res
let get self n = N_tbl.get self.values n
let on_new_term self cc n (t:T.t) =
match M.of_term n t with
| Some v ->
Log.debugf 20
(fun k->k "(@[monoid[%s].on-new-term@ :n %a@ :value %a@])"
M.name N.pp n M.pp v);
SI.CC.set_bitfield cc self.field_has_value true n;
N_tbl.add self.values n v
| None -> ()
(* find cell for [n] *)
let get_cell (self:t) (n:N.t) : M.t option =
N_tbl.get self.values n
(* TODO
if N.get_field self.field_has_value n then (
try Some (N_tbl.find self.values n)
with Not_found ->
Error.errorf "repr %a has value-field bit for %s set, but is not in table"
N.pp n M.name
) else (
None
)
*)
let on_pre_merge (self:t) cc acts n1 n2 e_n1_n2 : unit =
begin match get_cell self n1, get_cell self n2 with
| Some v1, Some v2 ->
Log.debugf 5
(fun k->k
"(@[monoid[%s].on_pre_merge@ @[:n1 %a@ :val %a@]@ @[:n2 %a@ :val %a@]@])"
M.name N.pp n1 M.pp v1 N.pp n2 M.pp v2);
begin match M.merge cc n1 v1 n2 v2 with
| Ok v' ->
N_tbl.add self.values n1 v';
| Error expl ->
(* add [n1=n2] to the conflict *)
let expl = Expl.mk_list [ e_n1_n2; expl; ] in
SI.CC.raise_conflict_from_expl cc acts expl
end
| None, Some cr ->
SI.CC.set_bitfield cc self.field_has_value true n1;
N_tbl.add self.values n1 cr
| Some _, None -> () (* already there on the left *)
| None, None -> ()
end
let create_and_setup ?size (solver:SI.t) : t =
let field_has_value = SI.CC.allocate_bitfield (SI.cc solver) in
let self = { values=N_tbl.create ?size (); field_has_value; } in
SI.on_cc_new_term solver (on_new_term self);
SI.on_cc_pre_merge solver (on_pre_merge self);
self
end

View file

@ -24,71 +24,61 @@ module Make(A : ARG) : S with module A = A = struct
module Fun = A.S.T.Fun
module Expl = SI.CC.Expl
type cstor_repr = {
t: T.t;
n: N.t;
cstor: Fun.t;
args: T.t IArray.t;
}
(* associate to each class a unique constructor term in the class (if any) *)
module Monoid = struct
module SI = SI
module N_tbl = Backtrackable_tbl.Make(N)
(* associate to each class a unique constructor term in the class (if any) *)
type t = {
t: T.t;
n: N.t;
cstor: Fun.t;
args: T.t IArray.t;
}
type t = {
cstors: cstor_repr N_tbl.t; (* repr -> cstor for the class *)
(* TODO: also allocate a bit in CC to filter out quickly classes without cstors? *)
}
let name = name
let pp out (v:t) =
Fmt.fprintf out "(@[cstor %a@ :term %a@])" Fun.pp v.cstor T.pp v.t
let push_level self = N_tbl.push_level self.cstors
let pop_levels self n = N_tbl.pop_levels self.cstors n
(* attach data to constructor terms *)
let of_term n (t:T.t) : _ option =
match A.view_as_cstor t with
| T_cstor (cstor,args) -> Some {n; t; cstor; args}
| _ -> None
(* attach data to constructor terms *)
let on_new_term self _solver n (t:T.t) =
match A.view_as_cstor t with
| T_cstor (cstor,args) ->
Log.debugf 20
(fun k->k "(@[th-cstor.on-new-term@ %a@ :cstor %a@ @[:args@ (@[%a@])@]@]@])"
T.pp t Fun.pp cstor (Util.pp_iarray T.pp) args);
N_tbl.add self.cstors n {n; t; cstor; args};
| _ -> ()
let on_pre_merge (self:t) cc acts n1 n2 e_n1_n2 : unit =
begin match N_tbl.get self.cstors n1, N_tbl.get self.cstors n2 with
| Some cr1, Some cr2 ->
Log.debugf 5
(fun k->k "(@[th-cstor.on_pre_merge@ @[:c1 %a@ (term %a)@]@ @[:c2 %a@ (term %a)@]@])"
N.pp n1 T.pp cr1.t N.pp n2 T.pp cr2.t);
(* build full explanation of why the constructor terms are equal *)
let expl =
Expl.mk_list [
e_n1_n2;
Expl.mk_merge n1 cr1.n;
Expl.mk_merge n2 cr2.n;
]
in
if Fun.equal cr1.cstor cr2.cstor then (
(* same function: injectivity *)
assert (IArray.length cr1.args = IArray.length cr2.args);
IArray.iter2
(fun u1 u2 -> SI.CC.merge_t cc u1 u2 expl)
cr1.args cr2.args
) else (
(* different function: disjointness *)
SI.CC.raise_conflict_from_expl cc acts expl
let merge cc n1 v1 n2 v2 : _ result =
Log.debugf 5
(fun k->k "(@[%s.merge@ @[:c1 %a (t %a)@]@ @[:c2 %a (t %a)@]@])"
name N.pp n1 T.pp v1.t N.pp n2 T.pp v2.t);
(* build full explanation of why the constructor terms are equal *)
let expl =
Expl.mk_list [
Expl.mk_merge n1 v1.n;
Expl.mk_merge n2 v2.n;
]
in
if Fun.equal v1.cstor v2.cstor then (
(* same function: injectivity *)
assert (IArray.length v1.args = IArray.length v2.args);
IArray.iter2
(fun u1 u2 -> SI.CC.merge_t cc u1 u2 expl)
v1.args v2.args;
Ok v1
) else (
(* different function: disjointness *)
Error expl
)
| None, Some cr ->
N_tbl.add self.cstors n1 cr
| Some _, None -> () (* already there on the left *)
| None, None -> ()
end
end
module ST = Sidekick_core.Monoid_of_repr(Monoid)
type t = ST.t
let push_level = ST.push_level
let pop_levels = ST.pop_levels
let create_and_setup (solver:SI.t) : t =
let self = {
cstors=N_tbl.create ~size:32 ();
} in
Log.debug 1 "(setup :th-cstor)";
SI.on_cc_new_term solver (on_new_term self);
SI.on_cc_pre_merge solver (on_pre_merge self);
let self = ST.create_and_setup ~size:32 solver in
self
let theory =

View file

@ -138,19 +138,57 @@ module Make(A : ARG) : S with module A = A = struct
module Card = Compute_card(A)
type cstor_repr = {
t: T.t;
n: N.t;
cstor: A.Cstor.t;
args: T.t IArray.t;
}
(* associate to each class a unique constructor term in the class (if any) *)
module Monoid_cstor = struct
module SI = SI
(* associate to each class a unique constructor term in the class (if any) *)
type t = {
t: T.t;
n: N.t;
cstor: A.Cstor.t;
args: T.t IArray.t;
}
let name = "th-data.cstor"
let pp out (v:t) =
Fmt.fprintf out "(@[cstor %a@ :term %a@])" A.Cstor.pp v.cstor T.pp v.t
(* attach data to constructor terms *)
let of_term n (t:T.t) : _ option =
match A.view_as_data t with
| T_cstor (cstor,args) -> Some {n; t; cstor; args}
| _ -> None
let merge cc n1 v1 n2 v2 : _ result =
Log.debugf 5
(fun k->k "(@[%s.merge@ @[:c1 %a (t %a)@]@ @[:c2 %a (t %a)@]@])"
name N.pp n1 T.pp v1.t N.pp n2 T.pp v2.t);
(* build full explanation of why the constructor terms are equal *)
let expl =
Expl.mk_list [
Expl.mk_merge n1 v1.n;
Expl.mk_merge n2 v2.n;
]
in
if A.Cstor.equal v1.cstor v2.cstor then (
(* same function: injectivity *)
assert (IArray.length v1.args = IArray.length v2.args);
IArray.iter2
(fun u1 u2 -> SI.CC.merge_t cc u1 u2 expl)
v1.args v2.args;
Ok v1
) else (
(* different function: disjointness *)
Error expl
)
end
module ST_cstors = Sidekick_core.Monoid_of_repr(Monoid_cstor)
module N_tbl = Backtrackable_tbl.Make(N)
type t = {
tst: T.state;
cstors: cstor_repr N_tbl.t; (* repr -> cstor for the class *)
cstors: ST_cstors.t; (* repr -> cstor for the class *)
cards: Card.t; (* remember finiteness *)
to_decide: unit N_tbl.t; (* set of terms to decide. *)
case_split_done: unit T.Tbl.t; (* set of terms for which case split is done *)
@ -159,18 +197,19 @@ module Make(A : ARG) : S with module A = A = struct
}
let push_level self =
N_tbl.push_level self.cstors;
ST_cstors.push_level self.cstors;
N_tbl.push_level self.to_decide;
()
let pop_levels self n =
N_tbl.pop_levels self.cstors n;
ST_cstors.pop_levels self.cstors n;
N_tbl.pop_levels self.to_decide n;
()
(* TODO: select/is-a *)
(* TODO: acyclicity *)
(* TODO: remove
(* attach data to constructor terms *)
let on_new_term_look_at_shape self n (t:T.t) =
match A.view_as_data t with
@ -193,6 +232,7 @@ module Make(A : ARG) : S with module A = A = struct
()
(* N_tbl.add self.cstors n {n; t; cstor; args}; *)
| T_other _ -> ()
*)
(* remember terms of a datatype *)
let on_new_term_look_at_ty (self:t) n (t:T.t) : unit =
@ -211,40 +251,9 @@ module Make(A : ARG) : S with module A = A = struct
| _ -> ()
let on_new_term self _solver n t =
on_new_term_look_at_shape self n t;
on_new_term_look_at_ty self n t;
()
let on_pre_merge (self:t) cc acts n1 n2 e_n1_n2 : unit =
begin match N_tbl.get self.cstors n1, N_tbl.get self.cstors n2 with
| Some cr1, Some cr2 ->
Log.debugf 5
(fun k->k "(@[th-cstor.on_pre_merge@ @[:c1 %a@ (term %a)@]@ @[:c2 %a@ (term %a)@]@])"
N.pp n1 T.pp cr1.t N.pp n2 T.pp cr2.t);
(* build full explanation of why the constructor terms are equal *)
let expl =
Expl.mk_list [
e_n1_n2;
Expl.mk_merge n1 cr1.n;
Expl.mk_merge n2 cr2.n;
]
in
if A.Cstor.equal cr1.cstor cr2.cstor then (
(* same function: injectivity *)
assert (IArray.length cr1.args = IArray.length cr2.args);
IArray.iter2
(fun u1 u2 -> SI.CC.merge_t cc u1 u2 expl)
cr1.args cr2.args
) else (
(* different function: disjointness *)
SI.CC.raise_conflict_from_expl cc acts expl
)
| None, Some cr ->
N_tbl.add self.cstors n1 cr
| Some _, None -> () (* already there on the left *)
| None, None -> ()
end
let cstors_of_ty (ty:Ty.t) : A.Cstor.t Iter.t =
match A.as_datatype ty with
| Ty_data {cstors} -> cstors
@ -258,7 +267,7 @@ module Make(A : ARG) : S with module A = A = struct
|> Iter.map (fun (n,_) -> SI.cc_find solver n)
|> Iter.filter
(fun n ->
not (N_tbl.mem self.cstors n) &&
not (ST_cstors.mem self.cstors n) &&
not (T.Tbl.mem self.case_split_done (N.term n)))
|> Iter.to_rev_list
in
@ -297,14 +306,13 @@ module Make(A : ARG) : S with module A = A = struct
let create_and_setup (solver:SI.t) : t =
let self = {
tst=SI.tst solver;
cstors=N_tbl.create ~size:32 ();
cstors=ST_cstors.create_and_setup ~size:32 solver;
to_decide=N_tbl.create ~size:16 ();
case_split_done=T.Tbl.create 16;
cards=Card.create();
} in
Log.debugf 1 (fun k->k "(setup :%s)" name);
SI.on_cc_new_term solver (on_new_term self);
SI.on_cc_pre_merge solver (on_pre_merge self);
SI.on_final_check solver (on_final_check self);
self

View file

@ -0,0 +1,57 @@
(** Datatype-oriented view of terms.
['c] is the representation of constructors
['t] is the representation of terms
*)
type ('c,'t) data_view =
| T_cstor of 'c * 't IArray.t
| T_select of 'c * int * 't
| T_is_a of 'c * 't
| T_other of 't
(** View of types in a way that is directly useful for the theory of datatypes *)
type ('c, 'ty) data_ty_view =
| Ty_arrow of 'ty Iter.t * 'ty
| Ty_app of {
args: 'ty Iter.t;
}
| Ty_data of {
cstors: 'c;
}
| Ty_other
module type ARG = sig
module S : Sidekick_core.SOLVER
module Cstor : sig
type t
val ty_args : t -> S.T.Ty.t Iter.t
val pp : t Fmt.printer
val equal : t -> t -> bool
end
val as_datatype : S.T.Ty.t -> (Cstor.t Iter.t, S.T.Ty.t) data_ty_view
(** Try to view type as a datatype (with its constructors) *)
val view_as_data : S.T.Term.t -> (Cstor.t, S.T.Term.t) data_view
(** Try to view term as a datatype term *)
val mk_cstor : S.T.Term.state -> Cstor.t -> S.T.Term.t IArray.t -> S.T.Term.t
(** Make a constructor application term *)
val mk_is_a: S.T.Term.state -> Cstor.t -> S.T.Term.t -> S.T.Term.t
(** Make a [is-a] term *)
val ty_is_finite : S.T.Ty.t -> bool
(** Is the given type known to be finite? *)
val ty_set_is_finite : S.T.Ty.t -> bool -> unit
(** Modify the "finite" field (see {!ty_is_finite}) *)
end
module type S = sig
module A : ARG
val theory : A.S.theory
end
module Make(A : ARG) : S with module A = A