From 28173c18520395a23054995d52d7d3d78a07718d Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Thu, 25 Aug 2022 20:50:56 -0400 Subject: [PATCH] feat(term): replace E_app_uncurried with E_app_fold --- src/base/Form.ml | 27 ++++--- src/base/th_data.ml | 2 +- src/core-logic/term.ml | 75 ++++++++++++++----- src/core-logic/term.mli | 8 +- src/core-logic/types_.ml | 6 +- src/core/t_printer.ml | 5 +- src/th-bool-static/Sidekick_th_bool_static.ml | 9 +-- 7 files changed, 88 insertions(+), 44 deletions(-) diff --git a/src/base/Form.ml b/src/base/Form.ml index 59ea3a01..6173d171 100644 --- a/src/base/Form.ml +++ b/src/base/Form.ml @@ -56,12 +56,16 @@ let view (t : T.t) : T.t view = else B_eq (a, b) | E_const { Const.c_view = T.C_ite; _ }, [ _ty; a; b; c ] -> B_ite (a, b, c) - | E_app_uncurried { c = { Const.c_view = C_and; _ }; args; _ }, _ -> - B_and args - | E_app_uncurried { c = { Const.c_view = C_or; _ }; args; _ }, _ -> B_or args - | E_app_uncurried { c = { Const.c_view = C_imply; _ }; args = [ a; b ]; _ }, _ - -> - B_imply (a, b) + | E_const { Const.c_view = C_imply; _ }, [ a; b ] -> B_imply (a, b) + | E_app_fold { f; args; acc0 }, [] -> + (match T.view f, T.view acc0 with + | ( E_const { Const.c_view = C_and; _ }, + E_const { Const.c_view = T.C_true; _ } ) -> + B_and args + | ( E_const { Const.c_view = C_or; _ }, + E_const { Const.c_view = T.C_false; _ } ) -> + B_or args + | _ -> B_atom t) | _ -> B_atom t let ty2b_ tst = @@ -75,20 +79,19 @@ let c_imply tst : Const.t = Const.make C_imply ops ~ty:(ty2b_ tst) let and_l tst = function | [] -> T.true_ tst | [ x ] -> x - | l -> Term.app_uncurried tst (c_and tst) l ~ty:(Term.bool tst) + | l -> + Term.app_fold tst l ~f:(Term.const tst @@ c_and tst) ~acc0:(T.true_ tst) let or_l tst = function | [] -> T.false_ tst | [ x ] -> x - | l -> Term.app_uncurried tst (c_or tst) l ~ty:(Term.bool tst) + | l -> + Term.app_fold tst l ~f:(Term.const tst @@ c_or tst) ~acc0:(T.false_ tst) let bool = Term.bool_val let and_ tst a b = and_l tst [ a; b ] let or_ tst a b = or_l tst [ a; b ] - -let imply tst a b : Term.t = - Term.app_uncurried tst (c_imply tst) [ a; b ] ~ty:(Term.bool tst) - +let imply tst a b : Term.t = T.app_l tst (T.const tst @@ c_imply tst) [ a; b ] let eq = T.eq let not_ = T.not let ite = T.ite diff --git a/src/base/th_data.ml b/src/base/th_data.ml index 338589bb..1001ce21 100644 --- a/src/base/th_data.ml +++ b/src/base/th_data.ml @@ -34,7 +34,7 @@ let arg = | None, E_pi (_, a, b) -> Ty_other { sub = [ a; b ] } | ( None, ( E_const _ | E_var _ | E_type _ | E_bound_var _ | E_lam _ - | E_app_uncurried _ ) ) -> + | E_app_fold _ ) ) -> Ty_other { sub = [] } ) diff --git a/src/core-logic/term.ml b/src/core-logic/term.ml index 319eedb7..33884ee3 100644 --- a/src/core-logic/term.ml +++ b/src/core-logic/term.ml @@ -10,7 +10,11 @@ type view = term_view = | E_bound_var of bvar | E_const of const | E_app of term * term - | E_app_uncurried of { c: const; ty: term; args: term list } + | E_app_fold of { + f: term; (** function to fold *) + args: term list; (** Arguments to the fold *) + acc0: term; (** initial accumulator *) + } | E_lam of string * term * term | E_pi of string * term * term @@ -75,9 +79,10 @@ let expr_pp_with_ ~pp_ids ~max_depth out (e : term) : unit = Fmt.fprintf out "(@[\\_:@[%a@].@ %a@])" pp' _ty (loop (k + 1) ~depth:(depth + 1) ("" :: names)) bod - | E_app_uncurried { c; args; ty = _ } -> - Fmt.fprintf out "(@[%a" Const.pp c; + | E_app_fold { f; args; acc0 } -> + Fmt.fprintf out "(@[%a" pp' f; List.iter (fun x -> Fmt.fprintf out "@ %a" pp' x) args; + Fmt.fprintf out "@ %a" pp' acc0; Fmt.fprintf out "@])" | E_lam (n, _ty, bod) -> Fmt.fprintf out "(@[\\%s:@[%a@].@ %a@])" n pp' _ty @@ -128,14 +133,15 @@ module Hcons = Hashcons.Make (struct | E_var v1, E_var v2 -> Var.equal v1 v2 | E_bound_var v1, E_bound_var v2 -> Bvar.equal v1 v2 | E_app (f1, a1), E_app (f2, a2) -> equal f1 f2 && equal a1 a2 - | E_app_uncurried a1, E_app_uncurried a2 -> - Const.equal a1.c a2.c && List.equal equal a1.args a2.args + | E_app_fold a1, E_app_fold a2 -> + equal a1.f a2.f && equal a1.acc0 a2.acc0 + && List.equal equal a1.args a2.args | E_lam (_, ty1, bod1), E_lam (_, ty2, bod2) -> equal ty1 ty2 && equal bod1 bod2 | E_pi (_, ty1, bod1), E_pi (_, ty2, bod2) -> equal ty1 ty2 && equal bod1 bod2 | ( ( E_type _ | E_const _ | E_var _ | E_bound_var _ | E_app _ - | E_app_uncurried _ | E_lam _ | E_pi _ ), + | E_app_fold _ | E_lam _ | E_pi _ ), _ ) -> false @@ -146,8 +152,8 @@ module Hcons = Hashcons.Make (struct | E_var v -> H.combine2 30 (Var.hash v) | E_bound_var v -> H.combine2 40 (Bvar.hash v) | E_app (f, a) -> H.combine3 50 (hash f) (hash a) - | E_app_uncurried a -> - H.combine3 55 (Const.hash a.c) (Hash.list hash a.args) + | E_app_fold a -> + H.combine4 55 (hash a.f) (hash a.acc0) (Hash.list hash a.args) | E_lam (_, ty, bod) -> H.combine3 60 (hash ty) (hash bod) | E_pi (_, ty, bod) -> H.combine3 70 (hash ty) (hash bod) @@ -189,8 +195,9 @@ let iter_shallow ~f (e : term) : unit = | E_app (hd, a) -> f false hd; f false a - | E_app_uncurried { ty; args; _ } -> - f false ty; + | E_app_fold { f = fold_f; args; acc0 } -> + f false fold_f; + f false acc0; List.iter (fun u -> f false u) args | E_lam (_, tyv, bod) | E_pi (_, tyv, bod) -> f false tyv; @@ -218,13 +225,14 @@ let map_shallow_ ~make ~f (e : term) : term = e else make (E_app (f false hd, f false a)) - | E_app_uncurried { args = l; c; ty } -> + | E_app_fold { f = fold_f; args = l; acc0 } -> + let fold_f' = f false fold_f in let l' = List.map (fun u -> f false u) l in - let ty' = f false ty in - if equal ty ty' && CCList.equal equal l l' then + let acc0' = f false acc0 in + if equal fold_f fold_f' && equal acc0 acc0' && CCList.equal equal l l' then e else - make (E_app_uncurried { c; ty = ty'; args = l' }) + make (E_app_fold { f = fold_f'; args = l'; acc0 = acc0' }) | E_lam (n, tyv, bod) -> let tyv' = f false tyv in let bod' = f true bod in @@ -304,8 +312,9 @@ module Make_ = struct | E_type _ | E_const _ | E_var _ -> 0 | E_bound_var v -> v.bv_idx + 1 | E_app (a, b) -> max (db_depth a) (db_depth b) - | E_app_uncurried { args; _ } -> - List.fold_left (fun x u -> max x (db_depth u)) 0 args + | E_app_fold { f; acc0; args } -> + let m = max (db_depth f) (db_depth acc0) in + List.fold_left (fun x u -> max x (db_depth u)) m args | E_lam (_, ty, bod) | E_pi (_, ty, bod) -> max (db_depth ty) (max 0 (db_depth bod - 1)) in @@ -322,7 +331,8 @@ module Make_ = struct | E_var _ -> true | E_type _ | E_bound_var _ | E_const _ -> false | E_app (a, b) -> has_fvars a || has_fvars b - | E_app_uncurried { args; _ } -> List.exists has_fvars args + | E_app_fold { f; acc0; args } -> + has_fvars f || has_fvars acc0 || List.exists has_fvars args | E_lam (_, ty, bod) | E_pi (_, ty, bod) -> has_fvars ty || has_fvars bod let universe_ (e : term) : int = @@ -450,7 +460,30 @@ module Make_ = struct "@[<2>cannot apply %a@ (to %a),@ must have Pi type, but actual type \ is %a@]" pp_debug f pp_debug a pp_debug ty_f) - | E_app_uncurried { ty; _ } -> ty + | E_app_fold { args = []; _ } -> assert false + | E_app_fold { f; args = a0 :: other_args as args; acc0 } -> + Store.check_e_uid store f; + Store.check_e_uid store acc0; + List.iter (Store.check_e_uid store) args; + let ty_result = ty acc0 in + let ty_a0 = ty a0 in + (* check that all arguments have the same type *) + List.iter + (fun a' -> + let ty' = ty a' in + if not (equal ty_a0 ty') then + Error.errorf + "app_fold: arguments %a@ and %a@ have incompatible types" pp_debug + a0 pp_debug a') + other_args; + (* check that [f a0 acc0] has type [ty_result] *) + let app1 = make (E_app (make (E_app (f, a0)), acc0)) in + if not (equal (ty app1) ty_result) then + Error.errorf + "app_fold: single application `%a`@ has type `%a`,@ but should have \ + type %a" + pp_debug app1 pp_debug (ty app1) pp_debug ty_result; + ty_result | E_pi (_, ty, bod) -> (* TODO: check the actual triplets for COC *) (*Fmt.printf "pi %a %a@." pp_debug ty pp_debug bod;*) @@ -501,8 +534,10 @@ module Make_ = struct let app store f a = make_ store (E_app (f, a)) let app_l store f l = List.fold_left (app store) f l - let app_uncurried store c args ~ty : t = - make_ store (E_app_uncurried { c; args; ty }) + let app_fold store ~f ~acc0 args : t = + match args with + | [] -> acc0 + | _ -> make_ store (E_app_fold { f; acc0; args }) type cache = t T_int_tbl.t diff --git a/src/core-logic/term.mli b/src/core-logic/term.mli index c3bc06a7..9bd8ec81 100644 --- a/src/core-logic/term.mli +++ b/src/core-logic/term.mli @@ -32,7 +32,11 @@ type view = term_view = | E_bound_var of bvar | E_const of const | E_app of t * t - | E_app_uncurried of { c: const; ty: term; args: term list } + | E_app_fold of { + f: term; (** function to fold *) + args: term list; (** Arguments to the fold *) + acc0: term; (** initial accumulator *) + } | E_lam of string * t * t | E_pi of string * t * t @@ -118,7 +122,7 @@ val bvar_i : store -> int -> ty:t -> t val const : store -> const -> t val app : store -> t -> t -> t val app_l : store -> t -> t list -> t -val app_uncurried : store -> const -> t list -> ty:t -> t +val app_fold : store -> f:t -> acc0:t -> t list -> t val lam : store -> var -> t -> t val pi : store -> var -> t -> t val arrow : store -> t -> t -> t diff --git a/src/core-logic/types_.ml b/src/core-logic/types_.ml index ac2f7e55..112e4153 100644 --- a/src/core-logic/types_.ml +++ b/src/core-logic/types_.ml @@ -16,7 +16,11 @@ type term_view = | E_bound_var of bvar | E_const of const | E_app of term * term - | E_app_uncurried of { c: const; ty: term; args: term list } + | E_app_fold of { + f: term; (** function to fold *) + args: term list; (** Arguments to the fold *) + acc0: term; (** initial accumulator *) + } | E_lam of string * term * term | E_pi of string * term * term diff --git a/src/core/t_printer.ml b/src/core/t_printer.ml index 92b885ff..6edf27b9 100644 --- a/src/core/t_printer.ml +++ b/src/core/t_printer.ml @@ -49,9 +49,10 @@ let expr_pp_with_ ~max_depth ~hooks out (e : term) : unit = | E_app _ -> let f, args = unfold_app e in Fmt.fprintf out "(%a@ %a)" pp' f (Util.pp_list pp') args - | E_app_uncurried { c; args; ty = _ } -> - Fmt.fprintf out "(@[%a" Const.pp c; + | E_app_fold { f; args; acc0 } -> + Fmt.fprintf out "(@[%a" pp' f; List.iter (fun x -> Fmt.fprintf out "@ %a" pp' x) args; + Fmt.fprintf out "@ %a" pp' acc0; Fmt.fprintf out "@])" | E_lam ("", _ty, bod) -> Fmt.fprintf out "(@[\\_:@[%a@].@ %a@])" pp' _ty diff --git a/src/th-bool-static/Sidekick_th_bool_static.ml b/src/th-bool-static/Sidekick_th_bool_static.ml index bb39457e..72b7eaae 100644 --- a/src/th-bool-static/Sidekick_th_bool_static.ml +++ b/src/th-bool-static/Sidekick_th_bool_static.ml @@ -114,12 +114,9 @@ end = struct None ) | B_imply (a, b) -> - if is_false a || is_true b then - ret (T.true_ tst) - else if is_true a && is_false b then - ret (T.false_ tst) - else - None + (* always rewrite [a => b] to [¬a \/ b] *) + let u = A.mk_bool tst (B_or [ T.not tst a; b ]) in + ret u | B_ite (a, b, c) -> (* directly simplify [a] so that maybe we never will simplify one of the branches *)