mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-08 04:05:43 -05:00
fix(core-logic/term): make ty unfailing; fix DB bugs
This commit is contained in:
parent
dbd20c999b
commit
bfa434562e
3 changed files with 59 additions and 52 deletions
|
|
@ -26,10 +26,13 @@ let[@inline] has_fvars e = (e.flags lsr store_id_bits) land 1 == 1
|
||||||
let[@inline] store_uid e : int = e.flags land store_id_mask
|
let[@inline] store_uid e : int = e.flags land store_id_mask
|
||||||
let[@inline] is_closed e : bool = db_depth e == 0
|
let[@inline] is_closed e : bool = db_depth e == 0
|
||||||
|
|
||||||
let[@inline] ty_exn e : term =
|
let[@inline] ty e : term =
|
||||||
match e.ty with
|
match e.ty with
|
||||||
| Some x -> x
|
| T_ty t -> t
|
||||||
| None -> assert false
|
| T_ty_delayed f ->
|
||||||
|
let ty = f () in
|
||||||
|
e.ty <- T_ty ty;
|
||||||
|
ty
|
||||||
|
|
||||||
(* open an application *)
|
(* open an application *)
|
||||||
let unfold_app (e : term) : term * term list =
|
let unfold_app (e : term) : term * term list =
|
||||||
|
|
@ -45,19 +48,15 @@ let expr_pp_with_ ~pp_ids ~max_depth out (e : term) : unit =
|
||||||
let rec loop k ~depth names out e =
|
let rec loop k ~depth names out e =
|
||||||
let pp' = loop' k ~depth:(depth + 1) names in
|
let pp' = loop' k ~depth:(depth + 1) names in
|
||||||
(match e.view with
|
(match e.view with
|
||||||
| E_type 0 -> Fmt.string out "type"
|
| E_type 0 -> Fmt.string out "Type"
|
||||||
| E_type i -> Fmt.fprintf out "type_%d" i
|
| E_type i -> Fmt.fprintf out "Type(%d)" i
|
||||||
| E_var v -> Fmt.string out v.v_name
|
| E_var v -> Fmt.string out v.v_name
|
||||||
(* | E_var v -> Fmt.fprintf out "(@[%s : %a@])" v.v_name pp v.v_ty *)
|
(* | E_var v -> Fmt.fprintf out "(@[%s : %a@])" v.v_name pp v.v_ty *)
|
||||||
| E_bound_var v ->
|
| E_bound_var v ->
|
||||||
let idx = v.bv_idx in
|
let idx = v.bv_idx in
|
||||||
(match CCList.nth_opt names idx with
|
(match CCList.nth_opt names idx with
|
||||||
| Some n when n <> "" -> Fmt.string out n
|
| Some n when n <> "" -> Fmt.fprintf out "%s[%d]" n idx
|
||||||
| _ ->
|
| _ -> Fmt.fprintf out "_[%d]" idx)
|
||||||
if idx < k then
|
|
||||||
Fmt.fprintf out "x_%d" (k - idx - 1)
|
|
||||||
else
|
|
||||||
Fmt.fprintf out "%%db_%d" (idx - k))
|
|
||||||
| E_const c -> Const.pp out c
|
| E_const c -> Const.pp out c
|
||||||
| (E_app _ | E_lam _) when depth > max_depth ->
|
| (E_app _ | E_lam _) when depth > max_depth ->
|
||||||
Fmt.fprintf out "@<1>…/%d" e.id
|
Fmt.fprintf out "@<1>…/%d" e.id
|
||||||
|
|
@ -65,7 +64,7 @@ let expr_pp_with_ ~pp_ids ~max_depth out (e : term) : unit =
|
||||||
let f, args = unfold_app e in
|
let f, args = unfold_app e in
|
||||||
Fmt.fprintf out "%a@ %a" pp' f (Util.pp_list pp') args
|
Fmt.fprintf out "%a@ %a" pp' f (Util.pp_list pp') args
|
||||||
| E_lam ("", _ty, bod) ->
|
| E_lam ("", _ty, bod) ->
|
||||||
Fmt.fprintf out "(@[\\x_%d:@[%a@].@ %a@])" k pp' _ty
|
Fmt.fprintf out "(@[\\_:@[%a@].@ %a@])" pp' _ty
|
||||||
(loop (k + 1) ~depth:(depth + 1) ("" :: names))
|
(loop (k + 1) ~depth:(depth + 1) ("" :: names))
|
||||||
bod
|
bod
|
||||||
| E_lam (n, _ty, bod) ->
|
| E_lam (n, _ty, bod) ->
|
||||||
|
|
@ -80,7 +79,7 @@ let expr_pp_with_ ~pp_ids ~max_depth out (e : term) : unit =
|
||||||
(loop (k + 1) ~depth:(depth + 1) ("" :: names))
|
(loop (k + 1) ~depth:(depth + 1) ("" :: names))
|
||||||
bod
|
bod
|
||||||
| E_pi ("", _ty, bod) ->
|
| E_pi ("", _ty, bod) ->
|
||||||
Fmt.fprintf out "(@[Pi x_%d:@[%a@].@ %a@])" k pp' _ty
|
Fmt.fprintf out "(@[Pi _:@[%a@].@ %a@])" pp' _ty
|
||||||
(loop (k + 1) ~depth:(depth + 1) ("" :: names))
|
(loop (k + 1) ~depth:(depth + 1) ("" :: names))
|
||||||
bod
|
bod
|
||||||
| E_pi (n, _ty, bod) ->
|
| E_pi (n, _ty, bod) ->
|
||||||
|
|
@ -125,6 +124,8 @@ module Hcons = Hashcons.Make (struct
|
||||||
| E_app (f1, a1), E_app (f2, a2) -> equal f1 f2 && equal a1 a2
|
| E_app (f1, a1), E_app (f2, a2) -> equal f1 f2 && equal a1 a2
|
||||||
| E_lam (_, ty1, bod1), E_lam (_, ty2, bod2) ->
|
| E_lam (_, ty1, bod1), E_lam (_, ty2, bod2) ->
|
||||||
equal ty1 ty2 && equal bod1 bod2
|
equal ty1 ty2 && equal bod1 bod2
|
||||||
|
| E_pi (_, ty1, bod1), E_pi (_, ty2, bod2) ->
|
||||||
|
equal ty1 ty2 && equal bod1 bod2
|
||||||
| ( ( E_type _ | E_const _ | E_var _ | E_bound_var _ | E_app _ | E_lam _
|
| ( ( E_type _ | E_const _ | E_var _ | E_bound_var _ | E_app _ | E_lam _
|
||||||
| E_pi _ ),
|
| E_pi _ ),
|
||||||
_ ) ->
|
_ ) ->
|
||||||
|
|
@ -168,9 +169,7 @@ let iter_shallow ~f (e : term) : unit =
|
||||||
match e.view with
|
match e.view with
|
||||||
| E_type _ -> ()
|
| E_type _ -> ()
|
||||||
| _ ->
|
| _ ->
|
||||||
(match e.ty with
|
f false (ty e);
|
||||||
| None -> (* should be computed at build time *) assert false
|
|
||||||
| Some ty -> f false ty);
|
|
||||||
(match e.view with
|
(match e.view with
|
||||||
| E_const _ -> ()
|
| E_const _ -> ()
|
||||||
| E_type _ -> assert false
|
| E_type _ -> assert false
|
||||||
|
|
@ -284,13 +283,11 @@ let[@inline] is_type_ e =
|
||||||
| E_type _ -> true
|
| E_type _ -> true
|
||||||
| _ -> false
|
| _ -> false
|
||||||
|
|
||||||
let[@inline] is_a_type e = is_type_ e || is_type_ (ty_exn e)
|
|
||||||
|
|
||||||
let iter_dag ?(seen = Tbl.create 8) ~iter_ty ~f e : unit =
|
let iter_dag ?(seen = Tbl.create 8) ~iter_ty ~f e : unit =
|
||||||
let rec loop e =
|
let rec loop e =
|
||||||
if not (Tbl.mem seen e) then (
|
if not (Tbl.mem seen e) then (
|
||||||
Tbl.add seen e ();
|
Tbl.add seen e ();
|
||||||
if iter_ty && not (is_type_ e) then loop (ty_exn e);
|
if iter_ty && not (is_type_ e) then loop (ty e);
|
||||||
f e;
|
f e;
|
||||||
iter_shallow e ~f:(fun _ u -> loop u)
|
iter_shallow e ~f:(fun _ u -> loop u)
|
||||||
)
|
)
|
||||||
|
|
@ -335,7 +332,7 @@ module Make_ = struct
|
||||||
if is_type_ e then
|
if is_type_ e then
|
||||||
0
|
0
|
||||||
else (
|
else (
|
||||||
let d1 = db_depth @@ ty_exn e in
|
let d1 = db_depth @@ ty e in
|
||||||
let d2 =
|
let d2 =
|
||||||
match view e with
|
match view e with
|
||||||
| E_type _ | E_const _ | E_var _ -> 0
|
| E_type _ | E_const _ | E_var _ -> 0
|
||||||
|
|
@ -351,7 +348,7 @@ module Make_ = struct
|
||||||
if is_type_ e then
|
if is_type_ e then
|
||||||
false
|
false
|
||||||
else
|
else
|
||||||
has_fvars (ty_exn e)
|
has_fvars (ty e)
|
||||||
||
|
||
|
||||||
match view e with
|
match view e with
|
||||||
| E_var _ -> true
|
| E_var _ -> true
|
||||||
|
|
@ -367,7 +364,7 @@ module Make_ = struct
|
||||||
let[@inline] universe_of_ty_ (e : term) : int =
|
let[@inline] universe_of_ty_ (e : term) : int =
|
||||||
match e.view with
|
match e.view with
|
||||||
| E_type i -> i + 1
|
| E_type i -> i + 1
|
||||||
| _ -> universe_ (ty_exn e)
|
| _ -> universe_ (ty e)
|
||||||
|
|
||||||
module T_int_tbl = CCHashtbl.Make (struct
|
module T_int_tbl = CCHashtbl.Make (struct
|
||||||
type t = term * int
|
type t = term * int
|
||||||
|
|
@ -376,11 +373,12 @@ module Make_ = struct
|
||||||
let hash (t, k) = H.combine3 27 (hash t) (H.int k)
|
let hash (t, k) = H.combine3 27 (hash t) (H.int k)
|
||||||
end)
|
end)
|
||||||
|
|
||||||
|
(* shift open bound variables of [e] by [n] *)
|
||||||
let db_shift_ ~make (e : term) (n : int) =
|
let db_shift_ ~make (e : term) (n : int) =
|
||||||
let rec loop e k : term =
|
let rec loop e k : term =
|
||||||
if is_closed e then
|
if is_closed e then
|
||||||
e
|
e
|
||||||
else if is_a_type e then
|
else if is_type_ e then
|
||||||
e
|
e
|
||||||
else (
|
else (
|
||||||
match view e with
|
match view e with
|
||||||
|
|
@ -408,8 +406,10 @@ module Make_ = struct
|
||||||
let db_0_replace_ ~make e ~by:u : term =
|
let db_0_replace_ ~make e ~by:u : term =
|
||||||
let cache_ = T_int_tbl.create 8 in
|
let cache_ = T_int_tbl.create 8 in
|
||||||
|
|
||||||
|
(* recurse in subterm [e], under [k] intermediate binders (so any
|
||||||
|
bound variable under k is bound by them) *)
|
||||||
let rec aux e k : term =
|
let rec aux e k : term =
|
||||||
if is_a_type e then
|
if is_type_ e then
|
||||||
e
|
e
|
||||||
else if db_depth e < k then
|
else if db_depth e < k then
|
||||||
e
|
e
|
||||||
|
|
@ -417,7 +417,8 @@ module Make_ = struct
|
||||||
match view e with
|
match view e with
|
||||||
| E_const _ -> e
|
| E_const _ -> e
|
||||||
| E_bound_var bv when bv.bv_idx = k ->
|
| E_bound_var bv when bv.bv_idx = k ->
|
||||||
(* replace here *)
|
(* replace [bv] with [u], and shift [u] to account for the
|
||||||
|
[k] intermediate binders we traversed to get to [bv] *)
|
||||||
db_shift_ ~make u k
|
db_shift_ ~make u k
|
||||||
| _ ->
|
| _ ->
|
||||||
(* use the cache *)
|
(* use the cache *)
|
||||||
|
|
@ -485,24 +486,30 @@ module Make_ = struct
|
||||||
| E_var v -> Var.ty v
|
| E_var v -> Var.ty v
|
||||||
| E_bound_var v -> Bvar.ty v
|
| E_bound_var v -> Bvar.ty v
|
||||||
| E_type i -> make (E_type (i + 1))
|
| E_type i -> make (E_type (i + 1))
|
||||||
| E_const c -> Const.ty c
|
| E_const c ->
|
||||||
| E_lam (name, ty, bod) ->
|
let ty = Const.ty c in
|
||||||
|
if not (is_closed ty) then
|
||||||
|
Error.errorf "const %a@ cannot have a non-closed type like %a" Const.pp
|
||||||
|
c pp_debug ty;
|
||||||
|
ty
|
||||||
|
| E_lam (name, ty_v, bod) ->
|
||||||
(* type of [\x:tau. bod] is [pi x:tau. typeof(bod)] *)
|
(* type of [\x:tau. bod] is [pi x:tau. typeof(bod)] *)
|
||||||
let ty_bod = ty_exn bod in
|
let ty_bod = ty bod in
|
||||||
make (E_pi (name, ty, ty_bod))
|
make (E_pi (name, ty_v, ty_bod))
|
||||||
| E_app (f, a) ->
|
| E_app (f, a) ->
|
||||||
(* type of [f a], where [a:tau] and [f: Pi x:tau. ty_bod_f],
|
(* type of [f a], where [a:tau] and [f: Pi x:tau. ty_bod_f],
|
||||||
is [ty_bod_f[x := a]] *)
|
is [ty_bod_f[x := a]] *)
|
||||||
let ty_f = ty_exn f in
|
let ty_f = ty f in
|
||||||
let ty_a = ty_exn a in
|
let ty_a = ty a in
|
||||||
(match ty_f.view with
|
(match ty_f.view with
|
||||||
| E_pi (_, ty_arg_f, ty_bod_f) ->
|
| E_pi (_, ty_arg_f, ty_bod_f) ->
|
||||||
(* check that the expected type matches *)
|
(* check that the expected type matches *)
|
||||||
if not (equal ty_arg_f ty_a) then
|
if not (equal ty_arg_f ty_a) then
|
||||||
Error.errorf
|
Error.errorf
|
||||||
"@[<2>cannot apply %a to %a,@ expected argument type: %a@ actual: \
|
"@[<2>cannot @[apply `%a`@]@ @[to `%a`@],@ expected argument type: \
|
||||||
%a@]"
|
`%a`@ @[actual: `%a`@]@]"
|
||||||
pp_debug f pp_debug a pp_debug ty_arg_f pp_debug ty_a;
|
pp_debug f pp_debug a pp_debug_with_ids ty_arg_f pp_debug_with_ids
|
||||||
|
ty_a;
|
||||||
db_0_replace_ ~make ty_bod_f ~by:a
|
db_0_replace_ ~make ty_bod_f ~by:a
|
||||||
| _ ->
|
| _ ->
|
||||||
Error.errorf
|
Error.errorf
|
||||||
|
|
@ -510,23 +517,30 @@ module Make_ = struct
|
||||||
pp_debug f pp_debug ty_f)
|
pp_debug f pp_debug ty_f)
|
||||||
| E_pi (_, ty, bod) ->
|
| E_pi (_, ty, bod) ->
|
||||||
(* TODO: check the actual triplets for COC *)
|
(* TODO: check the actual triplets for COC *)
|
||||||
Fmt.printf "pi %a %a@." pp_debug ty pp_debug bod;
|
(*Fmt.printf "pi %a %a@." pp_debug ty pp_debug bod;*)
|
||||||
let u = max (universe_of_ty_ ty) (universe_of_ty_ bod) + 1 in
|
let u = max (universe_of_ty_ ty) (universe_of_ty_ bod) in
|
||||||
make (E_type u)
|
make (E_type u)
|
||||||
|
|
||||||
|
let ty_assert_false_ () = assert false
|
||||||
|
|
||||||
(* hashconsing + computing metadata + computing type (for new terms) *)
|
(* hashconsing + computing metadata + computing type (for new terms) *)
|
||||||
let rec make_ (store : store) view : term =
|
let rec make_ (store : store) view : term =
|
||||||
let e = { view; ty = None; id = -1; flags = 0 } in
|
let e = { view; ty = T_ty_delayed ty_assert_false_; id = -1; flags = 0 } in
|
||||||
let e2 = Hcons.hashcons store.s_exprs e in
|
let e2 = Hcons.hashcons store.s_exprs e in
|
||||||
if e == e2 then (
|
if e == e2 then (
|
||||||
(* new term, compute metadata *)
|
(* new term, compute metadata *)
|
||||||
assert (store.s_uid land store_id_mask == store.s_uid);
|
assert (store.s_uid land store_id_mask == store.s_uid);
|
||||||
|
|
||||||
(* first, compute type *)
|
(* first, compute type *)
|
||||||
if not (is_type_ e) then (
|
(match e.view with
|
||||||
|
| E_type i ->
|
||||||
|
(* cannot force type now, as it's an infinite tower of types.
|
||||||
|
Instead we will produce the type on demand. *)
|
||||||
|
let get_ty () = make_ store (E_type (i + 1)) in
|
||||||
|
e.ty <- T_ty_delayed get_ty
|
||||||
|
| _ ->
|
||||||
let ty = compute_ty_ ~make:(make_ store) view in
|
let ty = compute_ty_ ~make:(make_ store) view in
|
||||||
e.ty <- Some ty
|
e.ty <- T_ty ty);
|
||||||
);
|
|
||||||
let has_fvars = compute_has_fvars_ e in
|
let has_fvars = compute_has_fvars_ e in
|
||||||
e2.flags <-
|
e2.flags <-
|
||||||
(compute_db_depth_ e lsl (1 + store_id_bits))
|
(compute_db_depth_ e lsl (1 + store_id_bits))
|
||||||
|
|
@ -606,11 +620,6 @@ end
|
||||||
|
|
||||||
include Make_
|
include Make_
|
||||||
|
|
||||||
let get_ty store e : term =
|
|
||||||
match e.view with
|
|
||||||
| E_type i -> type_of_univ store (i + 1)
|
|
||||||
| _ -> ty_exn e
|
|
||||||
|
|
||||||
(* re-export some internal things *)
|
(* re-export some internal things *)
|
||||||
module Internal_ = struct
|
module Internal_ = struct
|
||||||
let subst_ store ~recursive t subst =
|
let subst_ store ~recursive t subst =
|
||||||
|
|
|
||||||
|
|
@ -65,12 +65,8 @@ val has_fvars : t -> bool
|
||||||
(** Does the term contain free variables?
|
(** Does the term contain free variables?
|
||||||
time: O(1) *)
|
time: O(1) *)
|
||||||
|
|
||||||
val ty_exn : t -> t
|
val ty : t -> t
|
||||||
(** Return the type of this term. Fails if the term is a type. *)
|
(** Return the type of this term. *)
|
||||||
|
|
||||||
val get_ty : store -> t -> t
|
|
||||||
(** [get_ty store t] gets the type of [t], or computes it on demand
|
|
||||||
in case [t] is itself a type. *)
|
|
||||||
|
|
||||||
(** {2 Creation} *)
|
(** {2 Creation} *)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -26,12 +26,14 @@ and const = { c_view: const_view; c_ops: const_ops; c_ty: term }
|
||||||
and term = {
|
and term = {
|
||||||
view: term_view;
|
view: term_view;
|
||||||
(* computed on demand *)
|
(* computed on demand *)
|
||||||
mutable ty: term option;
|
mutable ty: term_ty_;
|
||||||
mutable id: int;
|
mutable id: int;
|
||||||
(* contains: [highest DB var | 1:has free vars | 5:ctx uid] *)
|
(* contains: [highest DB var | 1:has free vars | 5:ctx uid] *)
|
||||||
mutable flags: int;
|
mutable flags: int;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
and term_ty_ = T_ty of term | T_ty_delayed of (unit -> term)
|
||||||
|
|
||||||
module Term_ = struct
|
module Term_ = struct
|
||||||
let[@inline] equal (e1 : term) e2 : bool = e1 == e2
|
let[@inline] equal (e1 : term) e2 : bool = e1 == e2
|
||||||
let[@inline] hash (e : term) = H.int e.id
|
let[@inline] hash (e : term) = H.int e.id
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue