fix(core-logic/term): make ty unfailing; fix DB bugs

This commit is contained in:
Simon Cruanes 2022-07-28 14:51:24 -04:00
parent dbd20c999b
commit bfa434562e
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
3 changed files with 59 additions and 52 deletions

View file

@ -26,10 +26,13 @@ let[@inline] has_fvars e = (e.flags lsr store_id_bits) land 1 == 1
let[@inline] store_uid e : int = e.flags land store_id_mask
let[@inline] is_closed e : bool = db_depth e == 0
let[@inline] ty_exn e : term =
let[@inline] ty e : term =
match e.ty with
| Some x -> x
| None -> assert false
| T_ty t -> t
| T_ty_delayed f ->
let ty = f () in
e.ty <- T_ty ty;
ty
(* open an application *)
let unfold_app (e : term) : term * term list =
@ -45,19 +48,15 @@ let expr_pp_with_ ~pp_ids ~max_depth out (e : term) : unit =
let rec loop k ~depth names out e =
let pp' = loop' k ~depth:(depth + 1) names in
(match e.view with
| E_type 0 -> Fmt.string out "type"
| E_type i -> Fmt.fprintf out "type_%d" i
| E_type 0 -> Fmt.string out "Type"
| E_type i -> Fmt.fprintf out "Type(%d)" i
| E_var v -> Fmt.string out v.v_name
(* | E_var v -> Fmt.fprintf out "(@[%s : %a@])" v.v_name pp v.v_ty *)
| E_bound_var v ->
let idx = v.bv_idx in
(match CCList.nth_opt names idx with
| Some n when n <> "" -> Fmt.string out n
| _ ->
if idx < k then
Fmt.fprintf out "x_%d" (k - idx - 1)
else
Fmt.fprintf out "%%db_%d" (idx - k))
| Some n when n <> "" -> Fmt.fprintf out "%s[%d]" n idx
| _ -> Fmt.fprintf out "_[%d]" idx)
| E_const c -> Const.pp out c
| (E_app _ | E_lam _) when depth > max_depth ->
Fmt.fprintf out "@<1>…/%d" e.id
@ -65,7 +64,7 @@ let expr_pp_with_ ~pp_ids ~max_depth out (e : term) : unit =
let f, args = unfold_app e in
Fmt.fprintf out "%a@ %a" pp' f (Util.pp_list pp') args
| E_lam ("", _ty, bod) ->
Fmt.fprintf out "(@[\\x_%d:@[%a@].@ %a@])" k pp' _ty
Fmt.fprintf out "(@[\\_:@[%a@].@ %a@])" pp' _ty
(loop (k + 1) ~depth:(depth + 1) ("" :: names))
bod
| E_lam (n, _ty, bod) ->
@ -80,7 +79,7 @@ let expr_pp_with_ ~pp_ids ~max_depth out (e : term) : unit =
(loop (k + 1) ~depth:(depth + 1) ("" :: names))
bod
| E_pi ("", _ty, bod) ->
Fmt.fprintf out "(@[Pi x_%d:@[%a@].@ %a@])" k pp' _ty
Fmt.fprintf out "(@[Pi _:@[%a@].@ %a@])" pp' _ty
(loop (k + 1) ~depth:(depth + 1) ("" :: names))
bod
| E_pi (n, _ty, bod) ->
@ -125,6 +124,8 @@ module Hcons = Hashcons.Make (struct
| E_app (f1, a1), E_app (f2, a2) -> equal f1 f2 && equal a1 a2
| E_lam (_, ty1, bod1), E_lam (_, ty2, bod2) ->
equal ty1 ty2 && equal bod1 bod2
| E_pi (_, ty1, bod1), E_pi (_, ty2, bod2) ->
equal ty1 ty2 && equal bod1 bod2
| ( ( E_type _ | E_const _ | E_var _ | E_bound_var _ | E_app _ | E_lam _
| E_pi _ ),
_ ) ->
@ -168,9 +169,7 @@ let iter_shallow ~f (e : term) : unit =
match e.view with
| E_type _ -> ()
| _ ->
(match e.ty with
| None -> (* should be computed at build time *) assert false
| Some ty -> f false ty);
f false (ty e);
(match e.view with
| E_const _ -> ()
| E_type _ -> assert false
@ -284,13 +283,11 @@ let[@inline] is_type_ e =
| E_type _ -> true
| _ -> false
let[@inline] is_a_type e = is_type_ e || is_type_ (ty_exn e)
let iter_dag ?(seen = Tbl.create 8) ~iter_ty ~f e : unit =
let rec loop e =
if not (Tbl.mem seen e) then (
Tbl.add seen e ();
if iter_ty && not (is_type_ e) then loop (ty_exn e);
if iter_ty && not (is_type_ e) then loop (ty e);
f e;
iter_shallow e ~f:(fun _ u -> loop u)
)
@ -335,7 +332,7 @@ module Make_ = struct
if is_type_ e then
0
else (
let d1 = db_depth @@ ty_exn e in
let d1 = db_depth @@ ty e in
let d2 =
match view e with
| E_type _ | E_const _ | E_var _ -> 0
@ -351,7 +348,7 @@ module Make_ = struct
if is_type_ e then
false
else
has_fvars (ty_exn e)
has_fvars (ty e)
||
match view e with
| E_var _ -> true
@ -367,7 +364,7 @@ module Make_ = struct
let[@inline] universe_of_ty_ (e : term) : int =
match e.view with
| E_type i -> i + 1
| _ -> universe_ (ty_exn e)
| _ -> universe_ (ty e)
module T_int_tbl = CCHashtbl.Make (struct
type t = term * int
@ -376,11 +373,12 @@ module Make_ = struct
let hash (t, k) = H.combine3 27 (hash t) (H.int k)
end)
(* shift open bound variables of [e] by [n] *)
let db_shift_ ~make (e : term) (n : int) =
let rec loop e k : term =
if is_closed e then
e
else if is_a_type e then
else if is_type_ e then
e
else (
match view e with
@ -408,8 +406,10 @@ module Make_ = struct
let db_0_replace_ ~make e ~by:u : term =
let cache_ = T_int_tbl.create 8 in
(* recurse in subterm [e], under [k] intermediate binders (so any
bound variable under k is bound by them) *)
let rec aux e k : term =
if is_a_type e then
if is_type_ e then
e
else if db_depth e < k then
e
@ -417,7 +417,8 @@ module Make_ = struct
match view e with
| E_const _ -> e
| E_bound_var bv when bv.bv_idx = k ->
(* replace here *)
(* replace [bv] with [u], and shift [u] to account for the
[k] intermediate binders we traversed to get to [bv] *)
db_shift_ ~make u k
| _ ->
(* use the cache *)
@ -485,24 +486,30 @@ module Make_ = struct
| E_var v -> Var.ty v
| E_bound_var v -> Bvar.ty v
| E_type i -> make (E_type (i + 1))
| E_const c -> Const.ty c
| E_lam (name, ty, bod) ->
| E_const c ->
let ty = Const.ty c in
if not (is_closed ty) then
Error.errorf "const %a@ cannot have a non-closed type like %a" Const.pp
c pp_debug ty;
ty
| E_lam (name, ty_v, bod) ->
(* type of [\x:tau. bod] is [pi x:tau. typeof(bod)] *)
let ty_bod = ty_exn bod in
make (E_pi (name, ty, ty_bod))
let ty_bod = ty bod in
make (E_pi (name, ty_v, ty_bod))
| E_app (f, a) ->
(* type of [f a], where [a:tau] and [f: Pi x:tau. ty_bod_f],
is [ty_bod_f[x := a]] *)
let ty_f = ty_exn f in
let ty_a = ty_exn a in
let ty_f = ty f in
let ty_a = ty a in
(match ty_f.view with
| E_pi (_, ty_arg_f, ty_bod_f) ->
(* check that the expected type matches *)
if not (equal ty_arg_f ty_a) then
Error.errorf
"@[<2>cannot apply %a to %a,@ expected argument type: %a@ actual: \
%a@]"
pp_debug f pp_debug a pp_debug ty_arg_f pp_debug ty_a;
"@[<2>cannot @[apply `%a`@]@ @[to `%a`@],@ expected argument type: \
`%a`@ @[actual: `%a`@]@]"
pp_debug f pp_debug a pp_debug_with_ids ty_arg_f pp_debug_with_ids
ty_a;
db_0_replace_ ~make ty_bod_f ~by:a
| _ ->
Error.errorf
@ -510,23 +517,30 @@ module Make_ = struct
pp_debug f pp_debug ty_f)
| E_pi (_, ty, bod) ->
(* TODO: check the actual triplets for COC *)
Fmt.printf "pi %a %a@." pp_debug ty pp_debug bod;
let u = max (universe_of_ty_ ty) (universe_of_ty_ bod) + 1 in
(*Fmt.printf "pi %a %a@." pp_debug ty pp_debug bod;*)
let u = max (universe_of_ty_ ty) (universe_of_ty_ bod) in
make (E_type u)
let ty_assert_false_ () = assert false
(* hashconsing + computing metadata + computing type (for new terms) *)
let rec make_ (store : store) view : term =
let e = { view; ty = None; id = -1; flags = 0 } in
let e = { view; ty = T_ty_delayed ty_assert_false_; id = -1; flags = 0 } in
let e2 = Hcons.hashcons store.s_exprs e in
if e == e2 then (
(* new term, compute metadata *)
assert (store.s_uid land store_id_mask == store.s_uid);
(* first, compute type *)
if not (is_type_ e) then (
(match e.view with
| E_type i ->
(* cannot force type now, as it's an infinite tower of types.
Instead we will produce the type on demand. *)
let get_ty () = make_ store (E_type (i + 1)) in
e.ty <- T_ty_delayed get_ty
| _ ->
let ty = compute_ty_ ~make:(make_ store) view in
e.ty <- Some ty
);
e.ty <- T_ty ty);
let has_fvars = compute_has_fvars_ e in
e2.flags <-
(compute_db_depth_ e lsl (1 + store_id_bits))
@ -606,11 +620,6 @@ end
include Make_
let get_ty store e : term =
match e.view with
| E_type i -> type_of_univ store (i + 1)
| _ -> ty_exn e
(* re-export some internal things *)
module Internal_ = struct
let subst_ store ~recursive t subst =

View file

@ -65,12 +65,8 @@ val has_fvars : t -> bool
(** Does the term contain free variables?
time: O(1) *)
val ty_exn : t -> t
(** Return the type of this term. Fails if the term is a type. *)
val get_ty : store -> t -> t
(** [get_ty store t] gets the type of [t], or computes it on demand
in case [t] is itself a type. *)
val ty : t -> t
(** Return the type of this term. *)
(** {2 Creation} *)

View file

@ -26,12 +26,14 @@ and const = { c_view: const_view; c_ops: const_ops; c_ty: term }
and term = {
view: term_view;
(* computed on demand *)
mutable ty: term option;
mutable ty: term_ty_;
mutable id: int;
(* contains: [highest DB var | 1:has free vars | 5:ctx uid] *)
mutable flags: int;
}
and term_ty_ = T_ty of term | T_ty_delayed of (unit -> term)
module Term_ = struct
let[@inline] equal (e1 : term) e2 : bool = e1 == e2
let[@inline] hash (e : term) = H.int e.id