diff --git a/src/parser/ast_term.ml b/src/parser/ast_term.ml index 615fc572..63e40f05 100644 --- a/src/parser/ast_term.ml +++ b/src/parser/ast_term.ml @@ -15,11 +15,15 @@ and term_view = | Int of string | App of term * term list | Let of let_binding list * term - | Lambda of var list * term - | Pi of var list * term + | Lambda of ty_var_group list * term + | Pi of ty_var_group list * term | Arrow of term list * term | Error_node of string +and ty_var_group = + | VG_untyped of string + | VG_typed of { names: string list; ty: term } + and let_binding = var * term and var = { name: string; ty: term option } @@ -30,7 +34,12 @@ type decl = decl_view with_loc (* TODO: axiom *) and decl_view = - | D_def of { name: string; args: var list; ty_ret: term option; rhs: term } + | D_def of { + name: string; + args: ty_var_group list; + ty_ret: term option; + rhs: term; + } | D_hash of string * term | D_theorem of { name: string; goal: term; proof: proof } @@ -131,19 +140,34 @@ let rec pp_term out (e : term) : unit = let ppb out ((x, t) : let_binding) = Fmt.fprintf out "@[<2>%s :=@ %a@]" x.name pp t in - Fmt.fprintf out "@[@[<2>let@ @[%a@]@] in@ %a@]" - (Util.pp_list ~sep:" and " ppb) - bs pp bod + let ppbs out l = + List.iteri + (fun i x -> + if i > 0 then Fmt.fprintf out "@ and "; + ppb out x) + l + in + Fmt.fprintf out "@[@[let %a@] in@ %a@]" ppbs bs pp bod | Lambda (args, bod) -> - Fmt.fprintf out "@[lam %a.@ %a@]" (Util.pp_list pp_tyvar) args pp_sub bod + Fmt.fprintf out "@[lam%a.@ %a@]" + (Util.pp_list ~sep:"" pp_ty_var_group) + args pp_sub bod | Pi (args, bod) -> - Fmt.fprintf out "@[pi %a.@ %a@]" (Util.pp_list pp_tyvar) args pp_sub bod + Fmt.fprintf out "@[pi%a.@ %a@]" + (Util.pp_list ~sep:"" pp_ty_var_group) + args pp_sub bod and pp_tyvar out (x : var) : unit = match x.ty with | None -> Fmt.string out x.name | Some ty -> Fmt.fprintf out "(@[%s : %a@])" x.name pp_term ty +and pp_ty_var_group out (x : ty_var_group) : unit = + match x with + | VG_untyped x -> Fmt.fprintf out "@ %s" x + | VG_typed { names; ty } -> + Fmt.fprintf out "@ (@[%a : %a@])" (Util.pp_list Fmt.string) names pp_term ty + let rec pp_proof out (p : proof) : unit = match p.view with | P_by t -> Fmt.fprintf out "@[by@ %a@]" pp_term t @@ -166,7 +190,8 @@ let pp_decl out (d : decl) = | None -> () | Some ty -> Fmt.fprintf out " @[: %a@]" pp_term ty in - Fmt.fprintf out "@[<2>def %s%a%a :=@ %a@];" name (Util.pp_list pp_tyvar) + Fmt.fprintf out "@[<2>def %s%a%a :=@ %a@];" name + (Util.pp_list ~sep:"" pp_ty_var_group) args pp_tyret () pp_term rhs | D_hash (name, t) -> Fmt.fprintf out "@[<2>#%s@ %a@];" name pp_term t | D_theorem { name; goal; proof } -> diff --git a/src/parser/lex.mll b/src/parser/lex.mll index 34994989..be33a831 100644 --- a/src/parser/lex.mll +++ b/src/parser/lex.mll @@ -32,6 +32,7 @@ rule token = parse | "let" { LET } | "in" { IN } | "and" { AND } +| "def" { DEF } | "have" { HAVE } | "theorem" { THEOREM } | "by" { BY } diff --git a/src/parser/parse.mly b/src/parser/parse.mly index 8b7c382c..0e44316d 100644 --- a/src/parser/parse.mly +++ b/src/parser/parse.mly @@ -53,11 +53,11 @@ top_term: t=term EOF { t } decl: | h=HASH t=term SEMICOLON { let loc = Loc.of_lexloc $loc in + let h = String.sub h 1 (String.length h-1) in A.decl_hash ~loc h t } -| DEF name=name args=tyvars* ty_ret=optional_ty EQDEF rhs=term SEMICOLON { +| DEF name=name args=ty_var_group* ty_ret=optional_ty EQDEF rhs=term SEMICOLON { let loc = Loc.of_lexloc $loc in - let args = List.flatten args in A.decl_def ~loc name args ?ty_ret rhs } | THEOREM name=name EQDEF goal=term proof=proof SEMICOLON { @@ -88,10 +88,10 @@ proof_step: tyvar: | name=name ty=optional_ty { A.var ?ty name } -tyvars: -| name=name { [A.var name] } +ty_var_group: +| name=name { A.VG_untyped name } | LPAREN names=name+ COLON ty=term RPAREN { - List.map (fun name -> A.var ~ty name) names + A.VG_typed {names; ty} } %inline optional_ty: @@ -114,14 +114,12 @@ let_bindings: binder_term: | t=sym_term { t } -| FUNCTION vars=tyvars+ DOT rhs=binder_term { +| FUNCTION vars=ty_var_group+ DOT rhs=binder_term { let loc = Loc.of_lexloc $loc in - let vars = List.flatten vars in A.mk_lam ~loc vars rhs } -| PI vars=tyvars+ DOT rhs=binder_term { +| PI vars=ty_var_group+ DOT rhs=binder_term { let loc = Loc.of_lexloc $loc in - let vars = List.flatten vars in A.mk_pi ~loc vars rhs } diff --git a/unittest/parser/p1.expected b/unittest/parser/p1.expected index e3305f97..259ab264 100644 --- a/unittest/parser/p1.expected +++ b/unittest/parser/p1.expected @@ -3,11 +3,13 @@ loc(t1): at line 1, column 0 - at line 1, column 9 t2: let x := 1 in f (f x 2) loc(t2): at line 1, column 0 - at line 1, column 22 t3: let l := map f (list 1 2 3) in let l2 := rev l in = (rev l2) l -loc(t3): at line 1, column 1 - at line 1, column 61 +loc(t3): at line 1, column 0 - at line 1, column 60 t4: let assm := ==> (is_foo p) (= (filter p l) nil) in true loc(t4): at line 1, column 0 - at line 1, column 51 -t5: let - f := lam (x : int) (y : int) (z : bool). (= (+ x y) z) and - g := lam x. (f (f x)) in - is_g g +t5: let f := lam (x y : int) (z : bool). (= (+ x y) z) + and g := lam x. (f (f x)) in is_g g loc(t5): at line 1, column 0 - at line 1, column 84 +d1: + def f (x y : list int) : list int := if (= x 0) y whatever; + #ty f; + #sledgehammer lam x y. (= (f x y) (f x y)); diff --git a/unittest/parser/p1.ml b/unittest/parser/p1.ml index 6494c9ec..c10ce7fb 100644 --- a/unittest/parser/p1.ml +++ b/unittest/parser/p1.ml @@ -4,14 +4,13 @@ module A = Ast_term (* let () = Printexc.record_backtrace true *) -let () = Printexc.record_backtrace true let () = Printexc.register_printer (function | P.Exn_parse_error e -> Some (P.Error.to_string e) | _ -> None) -let test_str what s = +let test_term_str what s = let t = P.term_of_string s in match t with | Ok t -> @@ -20,22 +19,33 @@ let test_str what s = | Error err -> Fmt.printf "FAIL:@ error while parsing %S:@ %a@." what P.Error.pp err -let () = test_str "t1" "f (g x) y" -let () = test_str "t2" "let x:= 1 in f (f x 2)" +let () = test_term_str "t1" "f (g x) y" +let () = test_term_str "t2" "let x:= 1 in f (f x 2)" let () = - test_str "t3" - {| -let l := map f (list 1 2 3) in -let l2 := rev l in rev l2 = l - |} + test_term_str "t3" + {|let l := map f (list 1 2 3) in +let l2 := rev l in rev l2 = l|} let () = - test_str "t4" {|let assm := is_foo p ==> (filter p l = nil) in true - |} + test_term_str "t4" {|let assm := is_foo p ==> (filter p l = nil) in true|} let () = - test_str "t5" + test_term_str "t5" {|let f := fn (x y : int) (z:bool). ( x+ y) = z - and g := fn x. f (f x) in is_g g - |} + and g := fn x. f (f x) in is_g g|} + +let test_decl what s = + let t = P.decls_of_string s in + match t with + | Ok l -> + Fmt.printf "@[%s:@ %a@]@." what (Util.pp_list ~sep:"" A.pp_decl) l + | Error err -> + Fmt.printf "FAIL:@ error while parsing %S:@ %a@." what P.Error.pp err + +let () = + test_decl "d1" + {|def f (x y:list int) : list int := if (x = 0) y whatever; + #ty f; + #sledgehammer fn x y. f x y = f x y; + |}