fix(core-logic/term): make ty unfailing; fix DB bugs

This commit is contained in:
Simon Cruanes 2022-07-28 14:51:24 -04:00
parent dbd20c999b
commit bfa434562e
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
3 changed files with 59 additions and 52 deletions

View file

@ -26,10 +26,13 @@ let[@inline] has_fvars e = (e.flags lsr store_id_bits) land 1 == 1
let[@inline] store_uid e : int = e.flags land store_id_mask let[@inline] store_uid e : int = e.flags land store_id_mask
let[@inline] is_closed e : bool = db_depth e == 0 let[@inline] is_closed e : bool = db_depth e == 0
let[@inline] ty_exn e : term = let[@inline] ty e : term =
match e.ty with match e.ty with
| Some x -> x | T_ty t -> t
| None -> assert false | T_ty_delayed f ->
let ty = f () in
e.ty <- T_ty ty;
ty
(* open an application *) (* open an application *)
let unfold_app (e : term) : term * term list = let unfold_app (e : term) : term * term list =
@ -45,19 +48,15 @@ let expr_pp_with_ ~pp_ids ~max_depth out (e : term) : unit =
let rec loop k ~depth names out e = let rec loop k ~depth names out e =
let pp' = loop' k ~depth:(depth + 1) names in let pp' = loop' k ~depth:(depth + 1) names in
(match e.view with (match e.view with
| E_type 0 -> Fmt.string out "type" | E_type 0 -> Fmt.string out "Type"
| E_type i -> Fmt.fprintf out "type_%d" i | E_type i -> Fmt.fprintf out "Type(%d)" i
| E_var v -> Fmt.string out v.v_name | E_var v -> Fmt.string out v.v_name
(* | E_var v -> Fmt.fprintf out "(@[%s : %a@])" v.v_name pp v.v_ty *) (* | E_var v -> Fmt.fprintf out "(@[%s : %a@])" v.v_name pp v.v_ty *)
| E_bound_var v -> | E_bound_var v ->
let idx = v.bv_idx in let idx = v.bv_idx in
(match CCList.nth_opt names idx with (match CCList.nth_opt names idx with
| Some n when n <> "" -> Fmt.string out n | Some n when n <> "" -> Fmt.fprintf out "%s[%d]" n idx
| _ -> | _ -> Fmt.fprintf out "_[%d]" idx)
if idx < k then
Fmt.fprintf out "x_%d" (k - idx - 1)
else
Fmt.fprintf out "%%db_%d" (idx - k))
| E_const c -> Const.pp out c | E_const c -> Const.pp out c
| (E_app _ | E_lam _) when depth > max_depth -> | (E_app _ | E_lam _) when depth > max_depth ->
Fmt.fprintf out "@<1>…/%d" e.id Fmt.fprintf out "@<1>…/%d" e.id
@ -65,7 +64,7 @@ let expr_pp_with_ ~pp_ids ~max_depth out (e : term) : unit =
let f, args = unfold_app e in let f, args = unfold_app e in
Fmt.fprintf out "%a@ %a" pp' f (Util.pp_list pp') args Fmt.fprintf out "%a@ %a" pp' f (Util.pp_list pp') args
| E_lam ("", _ty, bod) -> | E_lam ("", _ty, bod) ->
Fmt.fprintf out "(@[\\x_%d:@[%a@].@ %a@])" k pp' _ty Fmt.fprintf out "(@[\\_:@[%a@].@ %a@])" pp' _ty
(loop (k + 1) ~depth:(depth + 1) ("" :: names)) (loop (k + 1) ~depth:(depth + 1) ("" :: names))
bod bod
| E_lam (n, _ty, bod) -> | E_lam (n, _ty, bod) ->
@ -80,7 +79,7 @@ let expr_pp_with_ ~pp_ids ~max_depth out (e : term) : unit =
(loop (k + 1) ~depth:(depth + 1) ("" :: names)) (loop (k + 1) ~depth:(depth + 1) ("" :: names))
bod bod
| E_pi ("", _ty, bod) -> | E_pi ("", _ty, bod) ->
Fmt.fprintf out "(@[Pi x_%d:@[%a@].@ %a@])" k pp' _ty Fmt.fprintf out "(@[Pi _:@[%a@].@ %a@])" pp' _ty
(loop (k + 1) ~depth:(depth + 1) ("" :: names)) (loop (k + 1) ~depth:(depth + 1) ("" :: names))
bod bod
| E_pi (n, _ty, bod) -> | E_pi (n, _ty, bod) ->
@ -125,6 +124,8 @@ module Hcons = Hashcons.Make (struct
| E_app (f1, a1), E_app (f2, a2) -> equal f1 f2 && equal a1 a2 | E_app (f1, a1), E_app (f2, a2) -> equal f1 f2 && equal a1 a2
| E_lam (_, ty1, bod1), E_lam (_, ty2, bod2) -> | E_lam (_, ty1, bod1), E_lam (_, ty2, bod2) ->
equal ty1 ty2 && equal bod1 bod2 equal ty1 ty2 && equal bod1 bod2
| E_pi (_, ty1, bod1), E_pi (_, ty2, bod2) ->
equal ty1 ty2 && equal bod1 bod2
| ( ( E_type _ | E_const _ | E_var _ | E_bound_var _ | E_app _ | E_lam _ | ( ( E_type _ | E_const _ | E_var _ | E_bound_var _ | E_app _ | E_lam _
| E_pi _ ), | E_pi _ ),
_ ) -> _ ) ->
@ -168,9 +169,7 @@ let iter_shallow ~f (e : term) : unit =
match e.view with match e.view with
| E_type _ -> () | E_type _ -> ()
| _ -> | _ ->
(match e.ty with f false (ty e);
| None -> (* should be computed at build time *) assert false
| Some ty -> f false ty);
(match e.view with (match e.view with
| E_const _ -> () | E_const _ -> ()
| E_type _ -> assert false | E_type _ -> assert false
@ -284,13 +283,11 @@ let[@inline] is_type_ e =
| E_type _ -> true | E_type _ -> true
| _ -> false | _ -> false
let[@inline] is_a_type e = is_type_ e || is_type_ (ty_exn e)
let iter_dag ?(seen = Tbl.create 8) ~iter_ty ~f e : unit = let iter_dag ?(seen = Tbl.create 8) ~iter_ty ~f e : unit =
let rec loop e = let rec loop e =
if not (Tbl.mem seen e) then ( if not (Tbl.mem seen e) then (
Tbl.add seen e (); Tbl.add seen e ();
if iter_ty && not (is_type_ e) then loop (ty_exn e); if iter_ty && not (is_type_ e) then loop (ty e);
f e; f e;
iter_shallow e ~f:(fun _ u -> loop u) iter_shallow e ~f:(fun _ u -> loop u)
) )
@ -335,7 +332,7 @@ module Make_ = struct
if is_type_ e then if is_type_ e then
0 0
else ( else (
let d1 = db_depth @@ ty_exn e in let d1 = db_depth @@ ty e in
let d2 = let d2 =
match view e with match view e with
| E_type _ | E_const _ | E_var _ -> 0 | E_type _ | E_const _ | E_var _ -> 0
@ -351,7 +348,7 @@ module Make_ = struct
if is_type_ e then if is_type_ e then
false false
else else
has_fvars (ty_exn e) has_fvars (ty e)
|| ||
match view e with match view e with
| E_var _ -> true | E_var _ -> true
@ -367,7 +364,7 @@ module Make_ = struct
let[@inline] universe_of_ty_ (e : term) : int = let[@inline] universe_of_ty_ (e : term) : int =
match e.view with match e.view with
| E_type i -> i + 1 | E_type i -> i + 1
| _ -> universe_ (ty_exn e) | _ -> universe_ (ty e)
module T_int_tbl = CCHashtbl.Make (struct module T_int_tbl = CCHashtbl.Make (struct
type t = term * int type t = term * int
@ -376,11 +373,12 @@ module Make_ = struct
let hash (t, k) = H.combine3 27 (hash t) (H.int k) let hash (t, k) = H.combine3 27 (hash t) (H.int k)
end) end)
(* shift open bound variables of [e] by [n] *)
let db_shift_ ~make (e : term) (n : int) = let db_shift_ ~make (e : term) (n : int) =
let rec loop e k : term = let rec loop e k : term =
if is_closed e then if is_closed e then
e e
else if is_a_type e then else if is_type_ e then
e e
else ( else (
match view e with match view e with
@ -408,8 +406,10 @@ module Make_ = struct
let db_0_replace_ ~make e ~by:u : term = let db_0_replace_ ~make e ~by:u : term =
let cache_ = T_int_tbl.create 8 in let cache_ = T_int_tbl.create 8 in
(* recurse in subterm [e], under [k] intermediate binders (so any
bound variable under k is bound by them) *)
let rec aux e k : term = let rec aux e k : term =
if is_a_type e then if is_type_ e then
e e
else if db_depth e < k then else if db_depth e < k then
e e
@ -417,7 +417,8 @@ module Make_ = struct
match view e with match view e with
| E_const _ -> e | E_const _ -> e
| E_bound_var bv when bv.bv_idx = k -> | E_bound_var bv when bv.bv_idx = k ->
(* replace here *) (* replace [bv] with [u], and shift [u] to account for the
[k] intermediate binders we traversed to get to [bv] *)
db_shift_ ~make u k db_shift_ ~make u k
| _ -> | _ ->
(* use the cache *) (* use the cache *)
@ -485,24 +486,30 @@ module Make_ = struct
| E_var v -> Var.ty v | E_var v -> Var.ty v
| E_bound_var v -> Bvar.ty v | E_bound_var v -> Bvar.ty v
| E_type i -> make (E_type (i + 1)) | E_type i -> make (E_type (i + 1))
| E_const c -> Const.ty c | E_const c ->
| E_lam (name, ty, bod) -> let ty = Const.ty c in
if not (is_closed ty) then
Error.errorf "const %a@ cannot have a non-closed type like %a" Const.pp
c pp_debug ty;
ty
| E_lam (name, ty_v, bod) ->
(* type of [\x:tau. bod] is [pi x:tau. typeof(bod)] *) (* type of [\x:tau. bod] is [pi x:tau. typeof(bod)] *)
let ty_bod = ty_exn bod in let ty_bod = ty bod in
make (E_pi (name, ty, ty_bod)) make (E_pi (name, ty_v, ty_bod))
| E_app (f, a) -> | E_app (f, a) ->
(* type of [f a], where [a:tau] and [f: Pi x:tau. ty_bod_f], (* type of [f a], where [a:tau] and [f: Pi x:tau. ty_bod_f],
is [ty_bod_f[x := a]] *) is [ty_bod_f[x := a]] *)
let ty_f = ty_exn f in let ty_f = ty f in
let ty_a = ty_exn a in let ty_a = ty a in
(match ty_f.view with (match ty_f.view with
| E_pi (_, ty_arg_f, ty_bod_f) -> | E_pi (_, ty_arg_f, ty_bod_f) ->
(* check that the expected type matches *) (* check that the expected type matches *)
if not (equal ty_arg_f ty_a) then if not (equal ty_arg_f ty_a) then
Error.errorf Error.errorf
"@[<2>cannot apply %a to %a,@ expected argument type: %a@ actual: \ "@[<2>cannot @[apply `%a`@]@ @[to `%a`@],@ expected argument type: \
%a@]" `%a`@ @[actual: `%a`@]@]"
pp_debug f pp_debug a pp_debug ty_arg_f pp_debug ty_a; pp_debug f pp_debug a pp_debug_with_ids ty_arg_f pp_debug_with_ids
ty_a;
db_0_replace_ ~make ty_bod_f ~by:a db_0_replace_ ~make ty_bod_f ~by:a
| _ -> | _ ->
Error.errorf Error.errorf
@ -510,23 +517,30 @@ module Make_ = struct
pp_debug f pp_debug ty_f) pp_debug f pp_debug ty_f)
| E_pi (_, ty, bod) -> | E_pi (_, ty, bod) ->
(* TODO: check the actual triplets for COC *) (* TODO: check the actual triplets for COC *)
Fmt.printf "pi %a %a@." pp_debug ty pp_debug bod; (*Fmt.printf "pi %a %a@." pp_debug ty pp_debug bod;*)
let u = max (universe_of_ty_ ty) (universe_of_ty_ bod) + 1 in let u = max (universe_of_ty_ ty) (universe_of_ty_ bod) in
make (E_type u) make (E_type u)
let ty_assert_false_ () = assert false
(* hashconsing + computing metadata + computing type (for new terms) *) (* hashconsing + computing metadata + computing type (for new terms) *)
let rec make_ (store : store) view : term = let rec make_ (store : store) view : term =
let e = { view; ty = None; id = -1; flags = 0 } in let e = { view; ty = T_ty_delayed ty_assert_false_; id = -1; flags = 0 } in
let e2 = Hcons.hashcons store.s_exprs e in let e2 = Hcons.hashcons store.s_exprs e in
if e == e2 then ( if e == e2 then (
(* new term, compute metadata *) (* new term, compute metadata *)
assert (store.s_uid land store_id_mask == store.s_uid); assert (store.s_uid land store_id_mask == store.s_uid);
(* first, compute type *) (* first, compute type *)
if not (is_type_ e) then ( (match e.view with
| E_type i ->
(* cannot force type now, as it's an infinite tower of types.
Instead we will produce the type on demand. *)
let get_ty () = make_ store (E_type (i + 1)) in
e.ty <- T_ty_delayed get_ty
| _ ->
let ty = compute_ty_ ~make:(make_ store) view in let ty = compute_ty_ ~make:(make_ store) view in
e.ty <- Some ty e.ty <- T_ty ty);
);
let has_fvars = compute_has_fvars_ e in let has_fvars = compute_has_fvars_ e in
e2.flags <- e2.flags <-
(compute_db_depth_ e lsl (1 + store_id_bits)) (compute_db_depth_ e lsl (1 + store_id_bits))
@ -606,11 +620,6 @@ end
include Make_ include Make_
let get_ty store e : term =
match e.view with
| E_type i -> type_of_univ store (i + 1)
| _ -> ty_exn e
(* re-export some internal things *) (* re-export some internal things *)
module Internal_ = struct module Internal_ = struct
let subst_ store ~recursive t subst = let subst_ store ~recursive t subst =

View file

@ -65,12 +65,8 @@ val has_fvars : t -> bool
(** Does the term contain free variables? (** Does the term contain free variables?
time: O(1) *) time: O(1) *)
val ty_exn : t -> t val ty : t -> t
(** Return the type of this term. Fails if the term is a type. *) (** Return the type of this term. *)
val get_ty : store -> t -> t
(** [get_ty store t] gets the type of [t], or computes it on demand
in case [t] is itself a type. *)
(** {2 Creation} *) (** {2 Creation} *)

View file

@ -26,12 +26,14 @@ and const = { c_view: const_view; c_ops: const_ops; c_ty: term }
and term = { and term = {
view: term_view; view: term_view;
(* computed on demand *) (* computed on demand *)
mutable ty: term option; mutable ty: term_ty_;
mutable id: int; mutable id: int;
(* contains: [highest DB var | 1:has free vars | 5:ctx uid] *) (* contains: [highest DB var | 1:has free vars | 5:ctx uid] *)
mutable flags: int; mutable flags: int;
} }
and term_ty_ = T_ty of term | T_ty_delayed of (unit -> term)
module Term_ = struct module Term_ = struct
let[@inline] equal (e1 : term) e2 : bool = e1 == e2 let[@inline] equal (e1 : term) e2 : bool = e1 == e2
let[@inline] hash (e : term) = H.int e.id let[@inline] hash (e : term) = H.int e.id