mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-08 04:05:43 -05:00
fix(core-logic/term): make ty unfailing; fix DB bugs
This commit is contained in:
parent
dbd20c999b
commit
bfa434562e
3 changed files with 59 additions and 52 deletions
|
|
@ -26,10 +26,13 @@ let[@inline] has_fvars e = (e.flags lsr store_id_bits) land 1 == 1
|
|||
let[@inline] store_uid e : int = e.flags land store_id_mask
|
||||
let[@inline] is_closed e : bool = db_depth e == 0
|
||||
|
||||
let[@inline] ty_exn e : term =
|
||||
let[@inline] ty e : term =
|
||||
match e.ty with
|
||||
| Some x -> x
|
||||
| None -> assert false
|
||||
| T_ty t -> t
|
||||
| T_ty_delayed f ->
|
||||
let ty = f () in
|
||||
e.ty <- T_ty ty;
|
||||
ty
|
||||
|
||||
(* open an application *)
|
||||
let unfold_app (e : term) : term * term list =
|
||||
|
|
@ -45,19 +48,15 @@ let expr_pp_with_ ~pp_ids ~max_depth out (e : term) : unit =
|
|||
let rec loop k ~depth names out e =
|
||||
let pp' = loop' k ~depth:(depth + 1) names in
|
||||
(match e.view with
|
||||
| E_type 0 -> Fmt.string out "type"
|
||||
| E_type i -> Fmt.fprintf out "type_%d" i
|
||||
| E_type 0 -> Fmt.string out "Type"
|
||||
| E_type i -> Fmt.fprintf out "Type(%d)" i
|
||||
| E_var v -> Fmt.string out v.v_name
|
||||
(* | E_var v -> Fmt.fprintf out "(@[%s : %a@])" v.v_name pp v.v_ty *)
|
||||
| E_bound_var v ->
|
||||
let idx = v.bv_idx in
|
||||
(match CCList.nth_opt names idx with
|
||||
| Some n when n <> "" -> Fmt.string out n
|
||||
| _ ->
|
||||
if idx < k then
|
||||
Fmt.fprintf out "x_%d" (k - idx - 1)
|
||||
else
|
||||
Fmt.fprintf out "%%db_%d" (idx - k))
|
||||
| Some n when n <> "" -> Fmt.fprintf out "%s[%d]" n idx
|
||||
| _ -> Fmt.fprintf out "_[%d]" idx)
|
||||
| E_const c -> Const.pp out c
|
||||
| (E_app _ | E_lam _) when depth > max_depth ->
|
||||
Fmt.fprintf out "@<1>…/%d" e.id
|
||||
|
|
@ -65,7 +64,7 @@ let expr_pp_with_ ~pp_ids ~max_depth out (e : term) : unit =
|
|||
let f, args = unfold_app e in
|
||||
Fmt.fprintf out "%a@ %a" pp' f (Util.pp_list pp') args
|
||||
| E_lam ("", _ty, bod) ->
|
||||
Fmt.fprintf out "(@[\\x_%d:@[%a@].@ %a@])" k pp' _ty
|
||||
Fmt.fprintf out "(@[\\_:@[%a@].@ %a@])" pp' _ty
|
||||
(loop (k + 1) ~depth:(depth + 1) ("" :: names))
|
||||
bod
|
||||
| E_lam (n, _ty, bod) ->
|
||||
|
|
@ -80,7 +79,7 @@ let expr_pp_with_ ~pp_ids ~max_depth out (e : term) : unit =
|
|||
(loop (k + 1) ~depth:(depth + 1) ("" :: names))
|
||||
bod
|
||||
| E_pi ("", _ty, bod) ->
|
||||
Fmt.fprintf out "(@[Pi x_%d:@[%a@].@ %a@])" k pp' _ty
|
||||
Fmt.fprintf out "(@[Pi _:@[%a@].@ %a@])" pp' _ty
|
||||
(loop (k + 1) ~depth:(depth + 1) ("" :: names))
|
||||
bod
|
||||
| E_pi (n, _ty, bod) ->
|
||||
|
|
@ -125,6 +124,8 @@ module Hcons = Hashcons.Make (struct
|
|||
| E_app (f1, a1), E_app (f2, a2) -> equal f1 f2 && equal a1 a2
|
||||
| E_lam (_, ty1, bod1), E_lam (_, ty2, bod2) ->
|
||||
equal ty1 ty2 && equal bod1 bod2
|
||||
| E_pi (_, ty1, bod1), E_pi (_, ty2, bod2) ->
|
||||
equal ty1 ty2 && equal bod1 bod2
|
||||
| ( ( E_type _ | E_const _ | E_var _ | E_bound_var _ | E_app _ | E_lam _
|
||||
| E_pi _ ),
|
||||
_ ) ->
|
||||
|
|
@ -168,9 +169,7 @@ let iter_shallow ~f (e : term) : unit =
|
|||
match e.view with
|
||||
| E_type _ -> ()
|
||||
| _ ->
|
||||
(match e.ty with
|
||||
| None -> (* should be computed at build time *) assert false
|
||||
| Some ty -> f false ty);
|
||||
f false (ty e);
|
||||
(match e.view with
|
||||
| E_const _ -> ()
|
||||
| E_type _ -> assert false
|
||||
|
|
@ -284,13 +283,11 @@ let[@inline] is_type_ e =
|
|||
| E_type _ -> true
|
||||
| _ -> false
|
||||
|
||||
let[@inline] is_a_type e = is_type_ e || is_type_ (ty_exn e)
|
||||
|
||||
let iter_dag ?(seen = Tbl.create 8) ~iter_ty ~f e : unit =
|
||||
let rec loop e =
|
||||
if not (Tbl.mem seen e) then (
|
||||
Tbl.add seen e ();
|
||||
if iter_ty && not (is_type_ e) then loop (ty_exn e);
|
||||
if iter_ty && not (is_type_ e) then loop (ty e);
|
||||
f e;
|
||||
iter_shallow e ~f:(fun _ u -> loop u)
|
||||
)
|
||||
|
|
@ -335,7 +332,7 @@ module Make_ = struct
|
|||
if is_type_ e then
|
||||
0
|
||||
else (
|
||||
let d1 = db_depth @@ ty_exn e in
|
||||
let d1 = db_depth @@ ty e in
|
||||
let d2 =
|
||||
match view e with
|
||||
| E_type _ | E_const _ | E_var _ -> 0
|
||||
|
|
@ -351,7 +348,7 @@ module Make_ = struct
|
|||
if is_type_ e then
|
||||
false
|
||||
else
|
||||
has_fvars (ty_exn e)
|
||||
has_fvars (ty e)
|
||||
||
|
||||
match view e with
|
||||
| E_var _ -> true
|
||||
|
|
@ -367,7 +364,7 @@ module Make_ = struct
|
|||
let[@inline] universe_of_ty_ (e : term) : int =
|
||||
match e.view with
|
||||
| E_type i -> i + 1
|
||||
| _ -> universe_ (ty_exn e)
|
||||
| _ -> universe_ (ty e)
|
||||
|
||||
module T_int_tbl = CCHashtbl.Make (struct
|
||||
type t = term * int
|
||||
|
|
@ -376,11 +373,12 @@ module Make_ = struct
|
|||
let hash (t, k) = H.combine3 27 (hash t) (H.int k)
|
||||
end)
|
||||
|
||||
(* shift open bound variables of [e] by [n] *)
|
||||
let db_shift_ ~make (e : term) (n : int) =
|
||||
let rec loop e k : term =
|
||||
if is_closed e then
|
||||
e
|
||||
else if is_a_type e then
|
||||
else if is_type_ e then
|
||||
e
|
||||
else (
|
||||
match view e with
|
||||
|
|
@ -408,8 +406,10 @@ module Make_ = struct
|
|||
let db_0_replace_ ~make e ~by:u : term =
|
||||
let cache_ = T_int_tbl.create 8 in
|
||||
|
||||
(* recurse in subterm [e], under [k] intermediate binders (so any
|
||||
bound variable under k is bound by them) *)
|
||||
let rec aux e k : term =
|
||||
if is_a_type e then
|
||||
if is_type_ e then
|
||||
e
|
||||
else if db_depth e < k then
|
||||
e
|
||||
|
|
@ -417,7 +417,8 @@ module Make_ = struct
|
|||
match view e with
|
||||
| E_const _ -> e
|
||||
| E_bound_var bv when bv.bv_idx = k ->
|
||||
(* replace here *)
|
||||
(* replace [bv] with [u], and shift [u] to account for the
|
||||
[k] intermediate binders we traversed to get to [bv] *)
|
||||
db_shift_ ~make u k
|
||||
| _ ->
|
||||
(* use the cache *)
|
||||
|
|
@ -485,24 +486,30 @@ module Make_ = struct
|
|||
| E_var v -> Var.ty v
|
||||
| E_bound_var v -> Bvar.ty v
|
||||
| E_type i -> make (E_type (i + 1))
|
||||
| E_const c -> Const.ty c
|
||||
| E_lam (name, ty, bod) ->
|
||||
| E_const c ->
|
||||
let ty = Const.ty c in
|
||||
if not (is_closed ty) then
|
||||
Error.errorf "const %a@ cannot have a non-closed type like %a" Const.pp
|
||||
c pp_debug ty;
|
||||
ty
|
||||
| E_lam (name, ty_v, bod) ->
|
||||
(* type of [\x:tau. bod] is [pi x:tau. typeof(bod)] *)
|
||||
let ty_bod = ty_exn bod in
|
||||
make (E_pi (name, ty, ty_bod))
|
||||
let ty_bod = ty bod in
|
||||
make (E_pi (name, ty_v, ty_bod))
|
||||
| E_app (f, a) ->
|
||||
(* type of [f a], where [a:tau] and [f: Pi x:tau. ty_bod_f],
|
||||
is [ty_bod_f[x := a]] *)
|
||||
let ty_f = ty_exn f in
|
||||
let ty_a = ty_exn a in
|
||||
let ty_f = ty f in
|
||||
let ty_a = ty a in
|
||||
(match ty_f.view with
|
||||
| E_pi (_, ty_arg_f, ty_bod_f) ->
|
||||
(* check that the expected type matches *)
|
||||
if not (equal ty_arg_f ty_a) then
|
||||
Error.errorf
|
||||
"@[<2>cannot apply %a to %a,@ expected argument type: %a@ actual: \
|
||||
%a@]"
|
||||
pp_debug f pp_debug a pp_debug ty_arg_f pp_debug ty_a;
|
||||
"@[<2>cannot @[apply `%a`@]@ @[to `%a`@],@ expected argument type: \
|
||||
`%a`@ @[actual: `%a`@]@]"
|
||||
pp_debug f pp_debug a pp_debug_with_ids ty_arg_f pp_debug_with_ids
|
||||
ty_a;
|
||||
db_0_replace_ ~make ty_bod_f ~by:a
|
||||
| _ ->
|
||||
Error.errorf
|
||||
|
|
@ -510,23 +517,30 @@ module Make_ = struct
|
|||
pp_debug f pp_debug ty_f)
|
||||
| E_pi (_, ty, bod) ->
|
||||
(* TODO: check the actual triplets for COC *)
|
||||
Fmt.printf "pi %a %a@." pp_debug ty pp_debug bod;
|
||||
let u = max (universe_of_ty_ ty) (universe_of_ty_ bod) + 1 in
|
||||
(*Fmt.printf "pi %a %a@." pp_debug ty pp_debug bod;*)
|
||||
let u = max (universe_of_ty_ ty) (universe_of_ty_ bod) in
|
||||
make (E_type u)
|
||||
|
||||
let ty_assert_false_ () = assert false
|
||||
|
||||
(* hashconsing + computing metadata + computing type (for new terms) *)
|
||||
let rec make_ (store : store) view : term =
|
||||
let e = { view; ty = None; id = -1; flags = 0 } in
|
||||
let e = { view; ty = T_ty_delayed ty_assert_false_; id = -1; flags = 0 } in
|
||||
let e2 = Hcons.hashcons store.s_exprs e in
|
||||
if e == e2 then (
|
||||
(* new term, compute metadata *)
|
||||
assert (store.s_uid land store_id_mask == store.s_uid);
|
||||
|
||||
(* first, compute type *)
|
||||
if not (is_type_ e) then (
|
||||
(match e.view with
|
||||
| E_type i ->
|
||||
(* cannot force type now, as it's an infinite tower of types.
|
||||
Instead we will produce the type on demand. *)
|
||||
let get_ty () = make_ store (E_type (i + 1)) in
|
||||
e.ty <- T_ty_delayed get_ty
|
||||
| _ ->
|
||||
let ty = compute_ty_ ~make:(make_ store) view in
|
||||
e.ty <- Some ty
|
||||
);
|
||||
e.ty <- T_ty ty);
|
||||
let has_fvars = compute_has_fvars_ e in
|
||||
e2.flags <-
|
||||
(compute_db_depth_ e lsl (1 + store_id_bits))
|
||||
|
|
@ -606,11 +620,6 @@ end
|
|||
|
||||
include Make_
|
||||
|
||||
let get_ty store e : term =
|
||||
match e.view with
|
||||
| E_type i -> type_of_univ store (i + 1)
|
||||
| _ -> ty_exn e
|
||||
|
||||
(* re-export some internal things *)
|
||||
module Internal_ = struct
|
||||
let subst_ store ~recursive t subst =
|
||||
|
|
|
|||
|
|
@ -65,12 +65,8 @@ val has_fvars : t -> bool
|
|||
(** Does the term contain free variables?
|
||||
time: O(1) *)
|
||||
|
||||
val ty_exn : t -> t
|
||||
(** Return the type of this term. Fails if the term is a type. *)
|
||||
|
||||
val get_ty : store -> t -> t
|
||||
(** [get_ty store t] gets the type of [t], or computes it on demand
|
||||
in case [t] is itself a type. *)
|
||||
val ty : t -> t
|
||||
(** Return the type of this term. *)
|
||||
|
||||
(** {2 Creation} *)
|
||||
|
||||
|
|
|
|||
|
|
@ -26,12 +26,14 @@ and const = { c_view: const_view; c_ops: const_ops; c_ty: term }
|
|||
and term = {
|
||||
view: term_view;
|
||||
(* computed on demand *)
|
||||
mutable ty: term option;
|
||||
mutable ty: term_ty_;
|
||||
mutable id: int;
|
||||
(* contains: [highest DB var | 1:has free vars | 5:ctx uid] *)
|
||||
mutable flags: int;
|
||||
}
|
||||
|
||||
and term_ty_ = T_ty of term | T_ty_delayed of (unit -> term)
|
||||
|
||||
module Term_ = struct
|
||||
let[@inline] equal (e1 : term) e2 : bool = e1 == e2
|
||||
let[@inline] hash (e : term) = H.int e.id
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue