refactor(model build): remove redundant class stuff

This commit is contained in:
Simon Cruanes 2022-10-19 22:28:38 -04:00
parent bfab613d58
commit 9c9a6e0da5
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
3 changed files with 13 additions and 18 deletions

View file

@ -5,7 +5,7 @@ module TM = Term.Map
type t = { type t = {
tst: Term.store; tst: Term.store;
mutable m: Term.t TM.t; mutable m: Term.t TM.t;
required: (Term.t * Term.t Iter.t) Queue.t; required: Term.t Queue.t;
gensym: Gensym.t; gensym: Gensym.t;
} }
@ -18,19 +18,19 @@ let pp out (self : t) : unit =
in in
Fmt.fprintf out "(@[model-builder@ :model %a@ :q (@[%a@])@])" Fmt.fprintf out "(@[model-builder@ :model %a@ :q (@[%a@])@])"
(Util.pp_iter pp_kv) (TM.to_iter self.m) (Util.pp_iter T.pp) (Util.pp_iter pp_kv) (TM.to_iter self.m) (Util.pp_iter T.pp)
(Iter.of_queue self.required |> Iter.map fst) (Iter.of_queue self.required)
let gensym self ~pre ~ty : Term.t = Gensym.fresh_term self.gensym ~pre ty let gensym self ~pre ~ty : Term.t = Gensym.fresh_term self.gensym ~pre ty
let require_eval (self : t) t ~cls : unit = let require_eval (self : t) t : unit =
if not @@ TM.mem t self.m then Queue.push (t, cls) self.required if not @@ TM.mem t self.m then Queue.push t self.required
let[@inline] mem self t : bool = TM.mem t self.m let[@inline] mem self t : bool = TM.mem t self.m
let add (self : t) ?(subs = []) t v : unit = let add (self : t) ?(subs = []) t v : unit =
if not @@ mem self t then ( if not @@ mem self t then (
self.m <- TM.add t v self.m; self.m <- TM.add t v self.m;
List.iter (fun u -> require_eval self u ~cls:Iter.empty) subs List.iter (fun u -> require_eval self u) subs
) )
type eval_cache = Term.Internal_.cache type eval_cache = Term.Internal_.cache
@ -53,9 +53,8 @@ let eval ?(cache = Term.Internal_.create_cache 8) (self : t) (t : Term.t) =
let rec pop_required (self : t) : _ option = let rec pop_required (self : t) : _ option =
match Queue.take_opt self.required with match Queue.take_opt self.required with
| None -> None | None -> None
| Some (t, cls) when TM.mem t self.m -> | Some t when TM.mem t self.m ->
(* make sure we also map [cls] to [t]'s value *) (* make sure we also map [cls] to [t]'s value *)
cls (fun u -> add self u (TM.find t self.m));
pop_required self pop_required self
| Some pair -> Some pair | Some pair -> Some pair

View file

@ -15,7 +15,7 @@ include Sidekick_sigs.PRINT with type t := t
val create : Term.store -> t val create : Term.store -> t
val mem : t -> Term.t -> bool val mem : t -> Term.t -> bool
val require_eval : t -> Term.t -> cls:Term.t Iter.t -> unit val require_eval : t -> Term.t -> unit
(** Require that this term gets a value, and assign it to all terms (** Require that this term gets a value, and assign it to all terms
in the given class. *) in the given class. *)
@ -33,7 +33,7 @@ val create_cache : int -> eval_cache
val eval : ?cache:eval_cache -> t -> Term.t -> value val eval : ?cache:eval_cache -> t -> Term.t -> value
val eval_opt : ?cache:eval_cache -> t -> Term.t -> value option val eval_opt : ?cache:eval_cache -> t -> Term.t -> value option
val pop_required : t -> (Term.t * Term.t Iter.t) option val pop_required : t -> Term.t option
(** gives the next subterm that is required but has no value yet *) (** gives the next subterm that is required but has no value yet *)
val to_map : ?cache:eval_cache -> t -> value Term.Map.t val to_map : ?cache:eval_cache -> t -> value Term.Map.t

View file

@ -363,8 +363,7 @@ let mk_model_ (self : t) (lits : lit Iter.t) : Model.t =
Log.debugf 5 (fun k -> Log.debugf 5 (fun k ->
k "(@[model.fixpoint.require-cls@ %a@])" E_node.pp repr); k "(@[model.fixpoint.require-cls@ %a@])" E_node.pp repr);
let t = E_node.term repr in let t = E_node.term repr in
let ts = E_node.iter_class repr |> Iter.map E_node.term in MB.require_eval model t);
MB.require_eval model t ~cls:ts);
(* now for the fixpoint. This is typically where composite theories such (* now for the fixpoint. This is typically where composite theories such
as arrays and datatypes contribute their skeleton values. *) as arrays and datatypes contribute their skeleton values. *)
@ -373,10 +372,10 @@ let mk_model_ (self : t) (lits : lit Iter.t) : Model.t =
while !continue do while !continue do
match MB.pop_required model with match MB.pop_required model with
| None -> continue := false | None -> continue := false
| Some (t, _cls) when Term.is_pi (Term.ty t) -> | Some t when Term.is_pi (Term.ty t) ->
(* TODO: when we support lambdas? *) (* TODO: when we support lambdas? *)
() ()
| Some (t, cls) -> | Some t ->
(* compute a value for [t] *) (* compute a value for [t] *)
Log.debugf 5 (fun k -> Log.debugf 5 (fun k ->
k "(@[model.fixpoint.compute-for-required@ %a@])" Term.pp t); k "(@[model.fixpoint.compute-for-required@ %a@])" Term.pp t);
@ -386,10 +385,7 @@ let mk_model_ (self : t) (lits : lit Iter.t) : Model.t =
| h :: hooks -> | h :: hooks ->
(match h self model t with (match h self model t with
| None -> try_hooks_ hooks | None -> try_hooks_ hooks
| Some (v, subs) -> | Some (v, subs) -> MB.add model ~subs t v)
MB.add model ~subs t v;
cls (fun u -> MB.add model ~subs:[] u v);
())
| [] -> | [] ->
(* should not happen *) (* should not happen *)
Error.errorf "cannot build a value for term@ `%a`@ of type `%a`" Error.errorf "cannot build a value for term@ `%a`@ of type `%a`"
@ -407,7 +403,7 @@ let mk_model_ (self : t) (lits : lit Iter.t) : Model.t =
let eval (t : Term.t) : value option = let eval (t : Term.t) : value option =
try Some (Term.Map.find t map) try Some (Term.Map.find t map)
with Not_found -> with Not_found ->
MB.require_eval model t ~cls:Iter.empty; MB.require_eval model t;
compute_fixpoint (); compute_fixpoint ();
MB.eval_opt ~cache model t MB.eval_opt ~cache model t
in in