diff --git a/src/core-logic/term.ml b/src/core-logic/term.ml index fe4bd9ca..18e542e7 100644 --- a/src/core-logic/term.ml +++ b/src/core-logic/term.ml @@ -26,10 +26,13 @@ let[@inline] has_fvars e = (e.flags lsr store_id_bits) land 1 == 1 let[@inline] store_uid e : int = e.flags land store_id_mask let[@inline] is_closed e : bool = db_depth e == 0 -let[@inline] ty_exn e : term = +let[@inline] ty e : term = match e.ty with - | Some x -> x - | None -> assert false + | T_ty t -> t + | T_ty_delayed f -> + let ty = f () in + e.ty <- T_ty ty; + ty (* open an application *) let unfold_app (e : term) : term * term list = @@ -45,19 +48,15 @@ let expr_pp_with_ ~pp_ids ~max_depth out (e : term) : unit = let rec loop k ~depth names out e = let pp' = loop' k ~depth:(depth + 1) names in (match e.view with - | E_type 0 -> Fmt.string out "type" - | E_type i -> Fmt.fprintf out "type_%d" i + | E_type 0 -> Fmt.string out "Type" + | E_type i -> Fmt.fprintf out "Type(%d)" i | E_var v -> Fmt.string out v.v_name (* | E_var v -> Fmt.fprintf out "(@[%s : %a@])" v.v_name pp v.v_ty *) | E_bound_var v -> let idx = v.bv_idx in (match CCList.nth_opt names idx with - | Some n when n <> "" -> Fmt.string out n - | _ -> - if idx < k then - Fmt.fprintf out "x_%d" (k - idx - 1) - else - Fmt.fprintf out "%%db_%d" (idx - k)) + | Some n when n <> "" -> Fmt.fprintf out "%s[%d]" n idx + | _ -> Fmt.fprintf out "_[%d]" idx) | E_const c -> Const.pp out c | (E_app _ | E_lam _) when depth > max_depth -> Fmt.fprintf out "@<1>…/%d" e.id @@ -65,7 +64,7 @@ let expr_pp_with_ ~pp_ids ~max_depth out (e : term) : unit = let f, args = unfold_app e in Fmt.fprintf out "%a@ %a" pp' f (Util.pp_list pp') args | E_lam ("", _ty, bod) -> - Fmt.fprintf out "(@[\\x_%d:@[%a@].@ %a@])" k pp' _ty + Fmt.fprintf out "(@[\\_:@[%a@].@ %a@])" pp' _ty (loop (k + 1) ~depth:(depth + 1) ("" :: names)) bod | E_lam (n, _ty, bod) -> @@ -80,7 +79,7 @@ let expr_pp_with_ ~pp_ids ~max_depth out (e : term) : unit = (loop (k + 1) ~depth:(depth + 1) ("" :: names)) bod | E_pi ("", _ty, bod) -> - Fmt.fprintf out "(@[Pi x_%d:@[%a@].@ %a@])" k pp' _ty + Fmt.fprintf out "(@[Pi _:@[%a@].@ %a@])" pp' _ty (loop (k + 1) ~depth:(depth + 1) ("" :: names)) bod | E_pi (n, _ty, bod) -> @@ -125,6 +124,8 @@ module Hcons = Hashcons.Make (struct | E_app (f1, a1), E_app (f2, a2) -> equal f1 f2 && equal a1 a2 | E_lam (_, ty1, bod1), E_lam (_, ty2, bod2) -> equal ty1 ty2 && equal bod1 bod2 + | E_pi (_, ty1, bod1), E_pi (_, ty2, bod2) -> + equal ty1 ty2 && equal bod1 bod2 | ( ( E_type _ | E_const _ | E_var _ | E_bound_var _ | E_app _ | E_lam _ | E_pi _ ), _ ) -> @@ -168,9 +169,7 @@ let iter_shallow ~f (e : term) : unit = match e.view with | E_type _ -> () | _ -> - (match e.ty with - | None -> (* should be computed at build time *) assert false - | Some ty -> f false ty); + f false (ty e); (match e.view with | E_const _ -> () | E_type _ -> assert false @@ -284,13 +283,11 @@ let[@inline] is_type_ e = | E_type _ -> true | _ -> false -let[@inline] is_a_type e = is_type_ e || is_type_ (ty_exn e) - let iter_dag ?(seen = Tbl.create 8) ~iter_ty ~f e : unit = let rec loop e = if not (Tbl.mem seen e) then ( Tbl.add seen e (); - if iter_ty && not (is_type_ e) then loop (ty_exn e); + if iter_ty && not (is_type_ e) then loop (ty e); f e; iter_shallow e ~f:(fun _ u -> loop u) ) @@ -335,7 +332,7 @@ module Make_ = struct if is_type_ e then 0 else ( - let d1 = db_depth @@ ty_exn e in + let d1 = db_depth @@ ty e in let d2 = match view e with | E_type _ | E_const _ | E_var _ -> 0 @@ -351,7 +348,7 @@ module Make_ = struct if is_type_ e then false else - has_fvars (ty_exn e) + has_fvars (ty e) || match view e with | E_var _ -> true @@ -367,7 +364,7 @@ module Make_ = struct let[@inline] universe_of_ty_ (e : term) : int = match e.view with | E_type i -> i + 1 - | _ -> universe_ (ty_exn e) + | _ -> universe_ (ty e) module T_int_tbl = CCHashtbl.Make (struct type t = term * int @@ -376,11 +373,12 @@ module Make_ = struct let hash (t, k) = H.combine3 27 (hash t) (H.int k) end) + (* shift open bound variables of [e] by [n] *) let db_shift_ ~make (e : term) (n : int) = let rec loop e k : term = if is_closed e then e - else if is_a_type e then + else if is_type_ e then e else ( match view e with @@ -408,8 +406,10 @@ module Make_ = struct let db_0_replace_ ~make e ~by:u : term = let cache_ = T_int_tbl.create 8 in + (* recurse in subterm [e], under [k] intermediate binders (so any + bound variable under k is bound by them) *) let rec aux e k : term = - if is_a_type e then + if is_type_ e then e else if db_depth e < k then e @@ -417,7 +417,8 @@ module Make_ = struct match view e with | E_const _ -> e | E_bound_var bv when bv.bv_idx = k -> - (* replace here *) + (* replace [bv] with [u], and shift [u] to account for the + [k] intermediate binders we traversed to get to [bv] *) db_shift_ ~make u k | _ -> (* use the cache *) @@ -485,24 +486,30 @@ module Make_ = struct | E_var v -> Var.ty v | E_bound_var v -> Bvar.ty v | E_type i -> make (E_type (i + 1)) - | E_const c -> Const.ty c - | E_lam (name, ty, bod) -> + | E_const c -> + let ty = Const.ty c in + if not (is_closed ty) then + Error.errorf "const %a@ cannot have a non-closed type like %a" Const.pp + c pp_debug ty; + ty + | E_lam (name, ty_v, bod) -> (* type of [\x:tau. bod] is [pi x:tau. typeof(bod)] *) - let ty_bod = ty_exn bod in - make (E_pi (name, ty, ty_bod)) + let ty_bod = ty bod in + make (E_pi (name, ty_v, ty_bod)) | E_app (f, a) -> (* type of [f a], where [a:tau] and [f: Pi x:tau. ty_bod_f], is [ty_bod_f[x := a]] *) - let ty_f = ty_exn f in - let ty_a = ty_exn a in + let ty_f = ty f in + let ty_a = ty a in (match ty_f.view with | E_pi (_, ty_arg_f, ty_bod_f) -> (* check that the expected type matches *) if not (equal ty_arg_f ty_a) then Error.errorf - "@[<2>cannot apply %a to %a,@ expected argument type: %a@ actual: \ - %a@]" - pp_debug f pp_debug a pp_debug ty_arg_f pp_debug ty_a; + "@[<2>cannot @[apply `%a`@]@ @[to `%a`@],@ expected argument type: \ + `%a`@ @[actual: `%a`@]@]" + pp_debug f pp_debug a pp_debug_with_ids ty_arg_f pp_debug_with_ids + ty_a; db_0_replace_ ~make ty_bod_f ~by:a | _ -> Error.errorf @@ -510,23 +517,30 @@ module Make_ = struct pp_debug f pp_debug ty_f) | E_pi (_, ty, bod) -> (* TODO: check the actual triplets for COC *) - Fmt.printf "pi %a %a@." pp_debug ty pp_debug bod; - let u = max (universe_of_ty_ ty) (universe_of_ty_ bod) + 1 in + (*Fmt.printf "pi %a %a@." pp_debug ty pp_debug bod;*) + let u = max (universe_of_ty_ ty) (universe_of_ty_ bod) in make (E_type u) + let ty_assert_false_ () = assert false + (* hashconsing + computing metadata + computing type (for new terms) *) let rec make_ (store : store) view : term = - let e = { view; ty = None; id = -1; flags = 0 } in + let e = { view; ty = T_ty_delayed ty_assert_false_; id = -1; flags = 0 } in let e2 = Hcons.hashcons store.s_exprs e in if e == e2 then ( (* new term, compute metadata *) assert (store.s_uid land store_id_mask == store.s_uid); (* first, compute type *) - if not (is_type_ e) then ( + (match e.view with + | E_type i -> + (* cannot force type now, as it's an infinite tower of types. + Instead we will produce the type on demand. *) + let get_ty () = make_ store (E_type (i + 1)) in + e.ty <- T_ty_delayed get_ty + | _ -> let ty = compute_ty_ ~make:(make_ store) view in - e.ty <- Some ty - ); + e.ty <- T_ty ty); let has_fvars = compute_has_fvars_ e in e2.flags <- (compute_db_depth_ e lsl (1 + store_id_bits)) @@ -606,11 +620,6 @@ end include Make_ -let get_ty store e : term = - match e.view with - | E_type i -> type_of_univ store (i + 1) - | _ -> ty_exn e - (* re-export some internal things *) module Internal_ = struct let subst_ store ~recursive t subst = diff --git a/src/core-logic/term.mli b/src/core-logic/term.mli index 24d4382d..1040262a 100644 --- a/src/core-logic/term.mli +++ b/src/core-logic/term.mli @@ -65,12 +65,8 @@ val has_fvars : t -> bool (** Does the term contain free variables? time: O(1) *) -val ty_exn : t -> t -(** Return the type of this term. Fails if the term is a type. *) - -val get_ty : store -> t -> t -(** [get_ty store t] gets the type of [t], or computes it on demand - in case [t] is itself a type. *) +val ty : t -> t +(** Return the type of this term. *) (** {2 Creation} *) diff --git a/src/core-logic/types_.ml b/src/core-logic/types_.ml index f62a5922..69d6e95d 100644 --- a/src/core-logic/types_.ml +++ b/src/core-logic/types_.ml @@ -26,12 +26,14 @@ and const = { c_view: const_view; c_ops: const_ops; c_ty: term } and term = { view: term_view; (* computed on demand *) - mutable ty: term option; + mutable ty: term_ty_; mutable id: int; (* contains: [highest DB var | 1:has free vars | 5:ctx uid] *) mutable flags: int; } +and term_ty_ = T_ty of term | T_ty_delayed of (unit -> term) + module Term_ = struct let[@inline] equal (e1 : term) e2 : bool = e1 == e2 let[@inline] hash (e : term) = H.int e.id