mirror of
https://github.com/c-cube/ocaml-containers.git
synced 2025-12-06 03:05:28 -05:00
add CCParse.memo for memoization (changes CCParse.input)
This commit is contained in:
parent
e34e8c8116
commit
59a138ec95
2 changed files with 107 additions and 16 deletions
|
|
@ -31,6 +31,14 @@ type 'a or_error = [`Ok of 'a | `Error of string]
|
||||||
type line_num = int
|
type line_num = int
|
||||||
type col_num = int
|
type col_num = int
|
||||||
|
|
||||||
|
(* Hashtable specialised to pairs of ints. The memoization code builds its
   keys as (position in input, unique parser id). *)
module H = Hashtbl.Make(struct
|
||||||
|
type t = int * int (* position in input, id of parser *)
|
||||||
|
let equal ((a,b):t)(c,d) = a=c && b=d
|
||||||
|
let hash = Hashtbl.hash
|
||||||
|
end)
|
||||||
|
|
||||||
|
(* Memoization table carried by an input. It is lazy, so the underlying
   hashtable is only built when first forced. Entries are closures of type
   [unit -> unit]; see [memo] for how a typed result is recovered from them. *)
type memo_ = (unit -> unit) H.t lazy_t
|
||||||
|
|
||||||
type input = {
|
type input = {
|
||||||
is_done : unit -> bool; (** End of input? *)
|
is_done : unit -> bool; (** End of input? *)
|
||||||
cur : unit -> char; (** Current char *)
|
cur : unit -> char; (** Current char *)
|
||||||
|
|
@ -38,43 +46,52 @@ type input = {
|
||||||
pos : unit -> int; (** Current pos *)
|
pos : unit -> int; (** Current pos *)
|
||||||
lnum : unit -> line_num; (** Line number @since NEXT_RELEASE *)
|
lnum : unit -> line_num; (** Line number @since NEXT_RELEASE *)
|
||||||
cnum : unit -> col_num; (** column number @since NEXT_RELEASE *)
|
cnum : unit -> col_num; (** column number @since NEXT_RELEASE *)
|
||||||
|
memo : memo_; (** memoization table, if any *)
|
||||||
backtrack : int -> unit; (** Restore to previous pos *)
|
backtrack : int -> unit; (** Restore to previous pos *)
|
||||||
sub : int -> int -> string; (** Extract slice from [pos] with [len] *)
|
sub : int -> int -> string; (** Extract slice from [pos] with [len] *)
|
||||||
}
|
}
|
||||||
|
|
||||||
exception ParseError of line_num * col_num * (unit -> string)
|
exception ParseError of line_num * col_num * (unit -> string)
|
||||||
|
|
||||||
(*$R
|
(*$inject
|
||||||
let module T = struct
|
module T = struct
|
||||||
type tree = L of int | N of tree * tree
|
type tree = L of int | N of tree * tree
|
||||||
end in
|
end
|
||||||
let open T in
|
open T
|
||||||
|
|
||||||
let mk_leaf x = L x in
|
let mk_leaf x = L x
|
||||||
let mk_node x y = N(x,y) in
|
let mk_node x y = N(x,y)
|
||||||
|
|
||||||
let ptree = fix @@ fun self ->
|
let ptree = fix @@ fun self ->
|
||||||
skip_space *>
|
skip_space *>
|
||||||
( (char '(' *> (pure mk_node <*> self <*> self) <* char ')')
|
( (char '(' *> (pure mk_node <*> self <*> self) <* char ')')
|
||||||
<|>
|
<|>
|
||||||
(U.int >|= mk_leaf) )
|
(U.int >|= mk_leaf) )
|
||||||
in
|
|
||||||
|
let ptree' = fix_memo @@ fun self ->
|
||||||
|
skip_space *>
|
||||||
|
( (char '(' *> (pure mk_node <*> self <*> self) <* char ')')
|
||||||
|
<|>
|
||||||
|
(U.int >|= mk_leaf) )
|
||||||
|
|
||||||
let rec pptree = function
|
let rec pptree = function
|
||||||
| N (a,b) -> Printf.sprintf "N (%s, %s)" (pptree a) (pptree b)
|
| N (a,b) -> Printf.sprintf "N (%s, %s)" (pptree a) (pptree b)
|
||||||
| L x -> Printf.sprintf "L %d" x
|
| L x -> Printf.sprintf "L %d" x
|
||||||
in
|
|
||||||
let errpptree = function
|
let errpptree = function
|
||||||
| `Ok x -> "Ok " ^ pptree x
|
| `Ok x -> "Ok " ^ pptree x
|
||||||
| `Error s -> "Error " ^ s
|
| `Error s -> "Error " ^ s
|
||||||
in
|
*)
|
||||||
|
|
||||||
assert_equal ~printer:errpptree
|
(*$= & ~printer:errpptree
|
||||||
(`Ok (N (L 1, N (L 2, L 3))))
|
(`Ok (N (L 1, N (L 2, L 3)))) \
|
||||||
(parse_string "(1 (2 3))" ptree);
|
(parse_string "(1 (2 3))" ptree)
|
||||||
assert_equal ~printer:errpptree
|
(`Ok (N (N (L 1, L 2), N (L 3, N (L 4, L 5))))) \
|
||||||
(`Ok (N (N (L 1, L 2), N (L 3, N (L 4, L 5)))))
|
(parse_string "((1 2) (3 (4 5)))" ptree)
|
||||||
(parse_string "((1 2) (3 (4 5)))" ptree);
|
(`Ok (N (L 1, N (L 2, L 3)))) \
|
||||||
|
(parse_string "(1 (2 3))" ptree' )
|
||||||
|
(`Ok (N (N (L 1, L 2), N (L 3, N (L 4, L 5))))) \
|
||||||
|
(parse_string "((1 2) (3 (4 5)))" ptree' )
|
||||||
*)
|
*)
|
||||||
|
|
||||||
(*$R
|
(*$R
|
||||||
|
|
@ -108,6 +125,7 @@ let input_of_string s =
|
||||||
);
|
);
|
||||||
lnum=(fun () -> !line);
|
lnum=(fun () -> !line);
|
||||||
cnum=(fun () -> !col);
|
cnum=(fun () -> !col);
|
||||||
|
memo=lazy (H.create 32);
|
||||||
pos=(fun () -> !i);
|
pos=(fun () -> !i);
|
||||||
backtrack=(fun j -> assert (0 <= j && j <= !i); i := j);
|
backtrack=(fun j -> assert (0 <= j && j <= !i); i := j);
|
||||||
sub=(fun j len -> assert (j + len <= !i); String.sub s j len);
|
sub=(fun j len -> assert (j + len <= !i); String.sub s j len);
|
||||||
|
|
@ -156,6 +174,7 @@ let input_of_chan ?(size=1024) ic =
|
||||||
pos=(fun() -> !i);
|
pos=(fun() -> !i);
|
||||||
lnum=(fun () -> !line);
|
lnum=(fun () -> !line);
|
||||||
cnum=(fun () -> !col);
|
cnum=(fun () -> !col);
|
||||||
|
memo=lazy (H.create 32);
|
||||||
backtrack=(fun j -> assert (0 <= j && j <= !i); i:=j);
|
backtrack=(fun j -> assert (0 <= j && j <= !i); i:=j);
|
||||||
sub=(fun j len -> assert (j + len <= !i); Bytes.sub_string !b j len);
|
sub=(fun j len -> assert (j + len <= !i); Bytes.sub_string !b j len);
|
||||||
}
|
}
|
||||||
|
|
@ -286,10 +305,58 @@ let rec sep1 ~by p =
|
||||||
and sep ~by p =
|
and sep ~by p =
|
||||||
sep1 ~by p <|> return []
|
sep1 ~by p <|> return []
|
||||||
|
|
||||||
|
(* Support definitions for parser memoization: the table type stored in an
   input, the counter used to number memoized parsers, and the type of a
   cached outcome. *)
module MemoTbl = struct
|
||||||
|
(* table of closures, used to implement universal type *)
|
||||||
|
type t = memo_
|
||||||
|
|
||||||
|
(* [create n] returns a suspended table; [H.create n] only runs when the
   lazy value is forced. *)
let create n = lazy (H.create n)
|
||||||
|
|
||||||
|
(* unique ID for each parser *)
|
||||||
|
(* global counter: [memo] reads the current value as the id of the new
   memoized parser, then increments it *)
let id_ = ref 0
|
||||||
|
|
||||||
|
(* cached outcome of running a parser at a position: either the exception
   it raised or the value it returned *)
type 'a res =
|
||||||
|
| Fail of exn
|
||||||
|
| Ok of 'a
|
||||||
|
end
|
||||||
|
|
||||||
let fix f =
|
let fix f =
|
||||||
let rec p st = f p st in
|
let rec p st = f p st in
|
||||||
p
|
p
|
||||||
|
|
||||||
|
(* [memo p] builds a parser that records the outcome of [p] in the memo table
   of the input, keyed by (position at entry, id of this memoized parser),
   and replays the recorded outcome when called again with the same key.
   Both returned values and [ParseError] exceptions are recorded; any other
   exception raised by [p] propagates without being stored.

   The table only holds [unit -> unit] closures, so the typed result is
   passed back through the local ref [r]: a stored closure writes its result
   into [r] when it is called.

   NOTE(review): on a cache hit the stored value is returned and nothing in
   this function moves [input] (only [input.pos ()] is read). If [p] consumed
   input when the entry was recorded, the caller resumes at position [i], not
   after the consumed text. Confirm whether this is intended. *)
let memo p =
|
||||||
|
(* take a fresh id for this memoized parser *)
let id = !MemoTbl.id_ in
|
||||||
|
incr MemoTbl.id_;
|
||||||
|
let r = ref None in (* used for universal encoding *)
|
||||||
|
fun input ->
|
||||||
|
let i = input.pos () in
|
||||||
|
(* forces the lazy table of this input *)
let (lazy tbl) = input.memo in
|
||||||
|
try
|
||||||
|
let f = H.find tbl (i, id) in
|
||||||
|
(* extract hidden value *)
|
||||||
|
(* reset [r] first so the [None] case below can only mean that [f] did not
   write into it *)
r := None;
|
||||||
|
f ();
|
||||||
|
begin match !r with
|
||||||
|
| None -> assert false
|
||||||
|
| Some (MemoTbl.Ok x) -> x
|
||||||
|
| Some (MemoTbl.Fail e) -> raise e
|
||||||
|
end
|
||||||
|
(* NOTE(review): this handler covers the whole [try] body above, i.e. also
   [f ()] and the replayed [raise e], not only [H.find] *)
with Not_found ->
|
||||||
|
(* parse, and save *)
|
||||||
|
try
|
||||||
|
let x = p input in
|
||||||
|
H.replace tbl (i,id) (fun () -> r := Some (MemoTbl.Ok x));
|
||||||
|
x
|
||||||
|
(* failures are cached as well, then re-raised to the caller *)
with (ParseError _) as e ->
|
||||||
|
H.replace tbl (i,id) (fun () -> r := Some (MemoTbl.Fail e));
|
||||||
|
raise e
|
||||||
|
|
||||||
|
(* [fix_memo f] is a fixpoint combinator like [fix], except that the parser
   handed to [f] for recursive calls is [memo p] rather than [p] itself.
   The memoized parser is built lazily, and only once, because [memo p]
   refers to [p] while [p] is still being defined. *)
let fix_memo f =
|
||||||
|
let rec p =
|
||||||
|
(* suspended: [memo p] runs the first time [p] is applied to an input *)
let p' = lazy (memo p) in
|
||||||
|
fun st -> f (Lazy.force p') st
|
||||||
|
in
|
||||||
|
p
|
||||||
|
|
||||||
let parse_exn ~input p = p input
|
let parse_exn ~input p = p input
|
||||||
|
|
||||||
let parse ~input p =
|
let parse ~input p =
|
||||||
|
|
|
||||||
|
|
@ -68,10 +68,18 @@ type line_num = int (** @since NEXT_RELEASE *)
|
||||||
type col_num = int (** @since NEXT_RELEASE *)
|
type col_num = int (** @since NEXT_RELEASE *)
|
||||||
|
|
||||||
exception ParseError of line_num * col_num * (unit -> string)
|
exception ParseError of line_num * col_num * (unit -> string)
|
||||||
(** position * message *)
|
(** position * message
|
||||||
|
|
||||||
|
This type changed at NEXT_RELEASE *)
|
||||||
|
|
||||||
(** {2 Input} *)
|
(** {2 Input} *)
|
||||||
|
|
||||||
|
(** Memoization tables. A table is stored in the [memo] field of {!input}
    and is used by {!memo} to cache parser results.
    @since NEXT_RELEASE *)
|
||||||
|
module MemoTbl : sig
|
||||||
|
type t
|
||||||
|
val create: int -> t (** New memoization table; the integer is passed to the creation function of the underlying hashtable *)
|
||||||
|
end
|
||||||
|
|
||||||
type input = {
|
type input = {
|
||||||
is_done : unit -> bool; (** End of input? *)
|
is_done : unit -> bool; (** End of input? *)
|
||||||
cur : unit -> char; (** Current char *)
|
cur : unit -> char; (** Current char *)
|
||||||
|
|
@ -83,6 +91,7 @@ type input = {
|
||||||
pos : unit -> int; (** Current pos *)
|
pos : unit -> int; (** Current pos *)
|
||||||
lnum : unit -> line_num; (** Line number @since NEXT_RELEASE *)
|
lnum : unit -> line_num; (** Line number @since NEXT_RELEASE *)
|
||||||
cnum : unit -> col_num; (** column number @since NEXT_RELEASE *)
|
cnum : unit -> col_num; (** column number @since NEXT_RELEASE *)
|
||||||
|
memo : MemoTbl.t; (** memoization table, if any *)
|
||||||
backtrack : int -> unit; (** Restore to previous pos *)
|
backtrack : int -> unit; (** Restore to previous pos *)
|
||||||
sub : int -> int -> string; (** [sub pos len] extracts slice from [pos] with [len] *)
|
sub : int -> int -> string; (** [sub pos len] extracts slice from [pos] with [len] *)
|
||||||
}
|
}
|
||||||
|
|
@ -214,6 +223,21 @@ val sep1 : by:_ t -> 'a t -> 'a list t
|
||||||
val fix : ('a t -> 'a t) -> 'a t
|
val fix : ('a t -> 'a t) -> 'a t
|
||||||
(** Fixpoint combinator *)
|
(** Fixpoint combinator *)
|
||||||
|
|
||||||
|
val memo : 'a t -> 'a t
|
||||||
|
(** Memoize the parser. [memo p] will behave like [p], but when called
|
||||||
|
in a state (read: position in input) it has already processed, [memo p]
|
||||||
|
returns a result directly. The implementation uses an underlying
|
||||||
|
hashtable, stored in the [memo] field of the input. Both successful
results and [ParseError] failures are cached.
|
||||||
|
This can be costly in memory, but can improve the run time a lot if there
|
||||||
|
is a lot of backtracking involving [p].
|
||||||
|
|
||||||
|
This function is not thread-safe.
|
||||||
|
@since NEXT_RELEASE *)
|
||||||
|
|
||||||
|
val fix_memo : ('a t -> 'a t) -> 'a t
|
||||||
|
(** Same as {!fix}, but the fixpoint is memoized: the parser passed to the
    function for recursive calls is wrapped with {!memo}.
|
||||||
|
@since NEXT_RELEASE *)
|
||||||
|
|
||||||
(** {2 Parse} *)
|
(** {2 Parse} *)
|
||||||
|
|
||||||
val parse : input:input -> 'a t -> 'a or_error
|
val parse : input:input -> 'a t -> 'a or_error
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue