feat(term): replace E_app_uncurried with E_app_fold

This commit is contained in:
Simon Cruanes 2022-08-25 20:50:56 -04:00
parent f6efc8f575
commit 28173c1852
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
7 changed files with 88 additions and 44 deletions

View file

@ -56,12 +56,16 @@ let view (t : T.t) : T.t view =
else else
B_eq (a, b) B_eq (a, b)
| E_const { Const.c_view = T.C_ite; _ }, [ _ty; a; b; c ] -> B_ite (a, b, c) | E_const { Const.c_view = T.C_ite; _ }, [ _ty; a; b; c ] -> B_ite (a, b, c)
| E_app_uncurried { c = { Const.c_view = C_and; _ }; args; _ }, _ -> | E_const { Const.c_view = C_imply; _ }, [ a; b ] -> B_imply (a, b)
B_and args | E_app_fold { f; args; acc0 }, [] ->
| E_app_uncurried { c = { Const.c_view = C_or; _ }; args; _ }, _ -> B_or args (match T.view f, T.view acc0 with
| E_app_uncurried { c = { Const.c_view = C_imply; _ }; args = [ a; b ]; _ }, _ | ( E_const { Const.c_view = C_and; _ },
-> E_const { Const.c_view = T.C_true; _ } ) ->
B_imply (a, b) B_and args
| ( E_const { Const.c_view = C_or; _ },
E_const { Const.c_view = T.C_false; _ } ) ->
B_or args
| _ -> B_atom t)
| _ -> B_atom t | _ -> B_atom t
let ty2b_ tst = let ty2b_ tst =
@ -75,20 +79,19 @@ let c_imply tst : Const.t = Const.make C_imply ops ~ty:(ty2b_ tst)
let and_l tst = function let and_l tst = function
| [] -> T.true_ tst | [] -> T.true_ tst
| [ x ] -> x | [ x ] -> x
| l -> Term.app_uncurried tst (c_and tst) l ~ty:(Term.bool tst) | l ->
Term.app_fold tst l ~f:(Term.const tst @@ c_and tst) ~acc0:(T.true_ tst)
let or_l tst = function let or_l tst = function
| [] -> T.false_ tst | [] -> T.false_ tst
| [ x ] -> x | [ x ] -> x
| l -> Term.app_uncurried tst (c_or tst) l ~ty:(Term.bool tst) | l ->
Term.app_fold tst l ~f:(Term.const tst @@ c_or tst) ~acc0:(T.false_ tst)
let bool = Term.bool_val let bool = Term.bool_val
let and_ tst a b = and_l tst [ a; b ] let and_ tst a b = and_l tst [ a; b ]
let or_ tst a b = or_l tst [ a; b ] let or_ tst a b = or_l tst [ a; b ]
let imply tst a b : Term.t = T.app_l tst (T.const tst @@ c_imply tst) [ a; b ]
let imply tst a b : Term.t =
Term.app_uncurried tst (c_imply tst) [ a; b ] ~ty:(Term.bool tst)
let eq = T.eq let eq = T.eq
let not_ = T.not let not_ = T.not
let ite = T.ite let ite = T.ite

View file

@ -34,7 +34,7 @@ let arg =
| None, E_pi (_, a, b) -> Ty_other { sub = [ a; b ] } | None, E_pi (_, a, b) -> Ty_other { sub = [ a; b ] }
| ( None, | ( None,
( E_const _ | E_var _ | E_type _ | E_bound_var _ | E_lam _ ( E_const _ | E_var _ | E_type _ | E_bound_var _ | E_lam _
| E_app_uncurried _ ) ) -> | E_app_fold _ ) ) ->
Ty_other { sub = [] } Ty_other { sub = [] }
) )

View file

@ -10,7 +10,11 @@ type view = term_view =
| E_bound_var of bvar | E_bound_var of bvar
| E_const of const | E_const of const
| E_app of term * term | E_app of term * term
| E_app_uncurried of { c: const; ty: term; args: term list } | E_app_fold of {
f: term; (** function to fold *)
args: term list; (** Arguments to the fold *)
acc0: term; (** initial accumulator *)
}
| E_lam of string * term * term | E_lam of string * term * term
| E_pi of string * term * term | E_pi of string * term * term
@ -75,9 +79,10 @@ let expr_pp_with_ ~pp_ids ~max_depth out (e : term) : unit =
Fmt.fprintf out "(@[\\_:@[%a@].@ %a@])" pp' _ty Fmt.fprintf out "(@[\\_:@[%a@].@ %a@])" pp' _ty
(loop (k + 1) ~depth:(depth + 1) ("" :: names)) (loop (k + 1) ~depth:(depth + 1) ("" :: names))
bod bod
| E_app_uncurried { c; args; ty = _ } -> | E_app_fold { f; args; acc0 } ->
Fmt.fprintf out "(@[%a" Const.pp c; Fmt.fprintf out "(@[%a" pp' f;
List.iter (fun x -> Fmt.fprintf out "@ %a" pp' x) args; List.iter (fun x -> Fmt.fprintf out "@ %a" pp' x) args;
Fmt.fprintf out "@ %a" pp' acc0;
Fmt.fprintf out "@])" Fmt.fprintf out "@])"
| E_lam (n, _ty, bod) -> | E_lam (n, _ty, bod) ->
Fmt.fprintf out "(@[\\%s:@[%a@].@ %a@])" n pp' _ty Fmt.fprintf out "(@[\\%s:@[%a@].@ %a@])" n pp' _ty
@ -128,14 +133,15 @@ module Hcons = Hashcons.Make (struct
| E_var v1, E_var v2 -> Var.equal v1 v2 | E_var v1, E_var v2 -> Var.equal v1 v2
| E_bound_var v1, E_bound_var v2 -> Bvar.equal v1 v2 | E_bound_var v1, E_bound_var v2 -> Bvar.equal v1 v2
| E_app (f1, a1), E_app (f2, a2) -> equal f1 f2 && equal a1 a2 | E_app (f1, a1), E_app (f2, a2) -> equal f1 f2 && equal a1 a2
| E_app_uncurried a1, E_app_uncurried a2 -> | E_app_fold a1, E_app_fold a2 ->
Const.equal a1.c a2.c && List.equal equal a1.args a2.args equal a1.f a2.f && equal a1.acc0 a2.acc0
&& List.equal equal a1.args a2.args
| E_lam (_, ty1, bod1), E_lam (_, ty2, bod2) -> | E_lam (_, ty1, bod1), E_lam (_, ty2, bod2) ->
equal ty1 ty2 && equal bod1 bod2 equal ty1 ty2 && equal bod1 bod2
| E_pi (_, ty1, bod1), E_pi (_, ty2, bod2) -> | E_pi (_, ty1, bod1), E_pi (_, ty2, bod2) ->
equal ty1 ty2 && equal bod1 bod2 equal ty1 ty2 && equal bod1 bod2
| ( ( E_type _ | E_const _ | E_var _ | E_bound_var _ | E_app _ | ( ( E_type _ | E_const _ | E_var _ | E_bound_var _ | E_app _
| E_app_uncurried _ | E_lam _ | E_pi _ ), | E_app_fold _ | E_lam _ | E_pi _ ),
_ ) -> _ ) ->
false false
@ -146,8 +152,8 @@ module Hcons = Hashcons.Make (struct
| E_var v -> H.combine2 30 (Var.hash v) | E_var v -> H.combine2 30 (Var.hash v)
| E_bound_var v -> H.combine2 40 (Bvar.hash v) | E_bound_var v -> H.combine2 40 (Bvar.hash v)
| E_app (f, a) -> H.combine3 50 (hash f) (hash a) | E_app (f, a) -> H.combine3 50 (hash f) (hash a)
| E_app_uncurried a -> | E_app_fold a ->
H.combine3 55 (Const.hash a.c) (Hash.list hash a.args) H.combine4 55 (hash a.f) (hash a.acc0) (Hash.list hash a.args)
| E_lam (_, ty, bod) -> H.combine3 60 (hash ty) (hash bod) | E_lam (_, ty, bod) -> H.combine3 60 (hash ty) (hash bod)
| E_pi (_, ty, bod) -> H.combine3 70 (hash ty) (hash bod) | E_pi (_, ty, bod) -> H.combine3 70 (hash ty) (hash bod)
@ -189,8 +195,9 @@ let iter_shallow ~f (e : term) : unit =
| E_app (hd, a) -> | E_app (hd, a) ->
f false hd; f false hd;
f false a f false a
| E_app_uncurried { ty; args; _ } -> | E_app_fold { f = fold_f; args; acc0 } ->
f false ty; f false fold_f;
f false acc0;
List.iter (fun u -> f false u) args List.iter (fun u -> f false u) args
| E_lam (_, tyv, bod) | E_pi (_, tyv, bod) -> | E_lam (_, tyv, bod) | E_pi (_, tyv, bod) ->
f false tyv; f false tyv;
@ -218,13 +225,14 @@ let map_shallow_ ~make ~f (e : term) : term =
e e
else else
make (E_app (f false hd, f false a)) make (E_app (f false hd, f false a))
| E_app_uncurried { args = l; c; ty } -> | E_app_fold { f = fold_f; args = l; acc0 } ->
let fold_f' = f false fold_f in
let l' = List.map (fun u -> f false u) l in let l' = List.map (fun u -> f false u) l in
let ty' = f false ty in let acc0' = f false acc0 in
if equal ty ty' && CCList.equal equal l l' then if equal fold_f fold_f' && equal acc0 acc0' && CCList.equal equal l l' then
e e
else else
make (E_app_uncurried { c; ty = ty'; args = l' }) make (E_app_fold { f = fold_f'; args = l'; acc0 = acc0' })
| E_lam (n, tyv, bod) -> | E_lam (n, tyv, bod) ->
let tyv' = f false tyv in let tyv' = f false tyv in
let bod' = f true bod in let bod' = f true bod in
@ -304,8 +312,9 @@ module Make_ = struct
| E_type _ | E_const _ | E_var _ -> 0 | E_type _ | E_const _ | E_var _ -> 0
| E_bound_var v -> v.bv_idx + 1 | E_bound_var v -> v.bv_idx + 1
| E_app (a, b) -> max (db_depth a) (db_depth b) | E_app (a, b) -> max (db_depth a) (db_depth b)
| E_app_uncurried { args; _ } -> | E_app_fold { f; acc0; args } ->
List.fold_left (fun x u -> max x (db_depth u)) 0 args let m = max (db_depth f) (db_depth acc0) in
List.fold_left (fun x u -> max x (db_depth u)) m args
| E_lam (_, ty, bod) | E_pi (_, ty, bod) -> | E_lam (_, ty, bod) | E_pi (_, ty, bod) ->
max (db_depth ty) (max 0 (db_depth bod - 1)) max (db_depth ty) (max 0 (db_depth bod - 1))
in in
@ -322,7 +331,8 @@ module Make_ = struct
| E_var _ -> true | E_var _ -> true
| E_type _ | E_bound_var _ | E_const _ -> false | E_type _ | E_bound_var _ | E_const _ -> false
| E_app (a, b) -> has_fvars a || has_fvars b | E_app (a, b) -> has_fvars a || has_fvars b
| E_app_uncurried { args; _ } -> List.exists has_fvars args | E_app_fold { f; acc0; args } ->
has_fvars f || has_fvars acc0 || List.exists has_fvars args
| E_lam (_, ty, bod) | E_pi (_, ty, bod) -> has_fvars ty || has_fvars bod | E_lam (_, ty, bod) | E_pi (_, ty, bod) -> has_fvars ty || has_fvars bod
let universe_ (e : term) : int = let universe_ (e : term) : int =
@ -450,7 +460,30 @@ module Make_ = struct
"@[<2>cannot apply %a@ (to %a),@ must have Pi type, but actual type \ "@[<2>cannot apply %a@ (to %a),@ must have Pi type, but actual type \
is %a@]" is %a@]"
pp_debug f pp_debug a pp_debug ty_f) pp_debug f pp_debug a pp_debug ty_f)
| E_app_uncurried { ty; _ } -> ty | E_app_fold { args = []; _ } -> assert false
| E_app_fold { f; args = a0 :: other_args as args; acc0 } ->
Store.check_e_uid store f;
Store.check_e_uid store acc0;
List.iter (Store.check_e_uid store) args;
let ty_result = ty acc0 in
let ty_a0 = ty a0 in
(* check that all arguments have the same type *)
List.iter
(fun a' ->
let ty' = ty a' in
if not (equal ty_a0 ty') then
Error.errorf
"app_fold: arguments %a@ and %a@ have incompatible types" pp_debug
a0 pp_debug a')
other_args;
(* check that [f a0 acc0] has type [ty_result] *)
let app1 = make (E_app (make (E_app (f, a0)), acc0)) in
if not (equal (ty app1) ty_result) then
Error.errorf
"app_fold: single application `%a`@ has type `%a`,@ but should have \
type %a"
pp_debug app1 pp_debug (ty app1) pp_debug ty_result;
ty_result
| E_pi (_, ty, bod) -> | E_pi (_, ty, bod) ->
(* TODO: check the actual triplets for COC *) (* TODO: check the actual triplets for COC *)
(*Fmt.printf "pi %a %a@." pp_debug ty pp_debug bod;*) (*Fmt.printf "pi %a %a@." pp_debug ty pp_debug bod;*)
@ -501,8 +534,10 @@ module Make_ = struct
let app store f a = make_ store (E_app (f, a)) let app store f a = make_ store (E_app (f, a))
let app_l store f l = List.fold_left (app store) f l let app_l store f l = List.fold_left (app store) f l
let app_uncurried store c args ~ty : t = let app_fold store ~f ~acc0 args : t =
make_ store (E_app_uncurried { c; args; ty }) match args with
| [] -> acc0
| _ -> make_ store (E_app_fold { f; acc0; args })
type cache = t T_int_tbl.t type cache = t T_int_tbl.t

View file

@ -32,7 +32,11 @@ type view = term_view =
| E_bound_var of bvar | E_bound_var of bvar
| E_const of const | E_const of const
| E_app of t * t | E_app of t * t
| E_app_uncurried of { c: const; ty: term; args: term list } | E_app_fold of {
f: term; (** function to fold *)
args: term list; (** Arguments to the fold *)
acc0: term; (** initial accumulator *)
}
| E_lam of string * t * t | E_lam of string * t * t
| E_pi of string * t * t | E_pi of string * t * t
@ -118,7 +122,7 @@ val bvar_i : store -> int -> ty:t -> t
val const : store -> const -> t val const : store -> const -> t
val app : store -> t -> t -> t val app : store -> t -> t -> t
val app_l : store -> t -> t list -> t val app_l : store -> t -> t list -> t
val app_uncurried : store -> const -> t list -> ty:t -> t val app_fold : store -> f:t -> acc0:t -> t list -> t
val lam : store -> var -> t -> t val lam : store -> var -> t -> t
val pi : store -> var -> t -> t val pi : store -> var -> t -> t
val arrow : store -> t -> t -> t val arrow : store -> t -> t -> t

View file

@ -16,7 +16,11 @@ type term_view =
| E_bound_var of bvar | E_bound_var of bvar
| E_const of const | E_const of const
| E_app of term * term | E_app of term * term
| E_app_uncurried of { c: const; ty: term; args: term list } | E_app_fold of {
f: term; (** function to fold *)
args: term list; (** Arguments to the fold *)
acc0: term; (** initial accumulator *)
}
| E_lam of string * term * term | E_lam of string * term * term
| E_pi of string * term * term | E_pi of string * term * term

View file

@ -49,9 +49,10 @@ let expr_pp_with_ ~max_depth ~hooks out (e : term) : unit =
| E_app _ -> | E_app _ ->
let f, args = unfold_app e in let f, args = unfold_app e in
Fmt.fprintf out "(%a@ %a)" pp' f (Util.pp_list pp') args Fmt.fprintf out "(%a@ %a)" pp' f (Util.pp_list pp') args
| E_app_uncurried { c; args; ty = _ } -> | E_app_fold { f; args; acc0 } ->
Fmt.fprintf out "(@[%a" Const.pp c; Fmt.fprintf out "(@[%a" pp' f;
List.iter (fun x -> Fmt.fprintf out "@ %a" pp' x) args; List.iter (fun x -> Fmt.fprintf out "@ %a" pp' x) args;
Fmt.fprintf out "@ %a" pp' acc0;
Fmt.fprintf out "@])" Fmt.fprintf out "@])"
| E_lam ("", _ty, bod) -> | E_lam ("", _ty, bod) ->
Fmt.fprintf out "(@[\\_:@[%a@].@ %a@])" pp' _ty Fmt.fprintf out "(@[\\_:@[%a@].@ %a@])" pp' _ty

View file

@ -114,12 +114,9 @@ end = struct
None None
) )
| B_imply (a, b) -> | B_imply (a, b) ->
if is_false a || is_true b then (* always rewrite [a => b] to [¬a \/ b] *)
ret (T.true_ tst) let u = A.mk_bool tst (B_or [ T.not tst a; b ]) in
else if is_true a && is_false b then ret u
ret (T.false_ tst)
else
None
| B_ite (a, b, c) -> | B_ite (a, b, c) ->
(* directly simplify [a] so that maybe we never will simplify one (* directly simplify [a] so that maybe we never will simplify one
of the branches *) of the branches *)