diff --git a/bij.ml b/bij.ml index 6bf1eb4c..3f0bf337 100644 --- a/bij.ml +++ b/bij.ml @@ -37,7 +37,11 @@ type _ t = | Pair : 'a t * 'b t -> ('a * 'b) t | Triple : 'a t * 'b t * 'c t -> ('a * 'b * 'c) t | Map : ('a -> 'b) * ('b -> 'a) * 'b t -> 'a t - | Switch : ('a -> char) * (char * 'a t) list -> 'a t + | Switch : ('a -> char * 'a inject_branch) * (char -> 'a extract_branch) -> 'a t +and _ inject_branch = + | BranchTo : 'b t * 'b * 'a -> 'a inject_branch +and _ extract_branch = + | BranchFrom : 'b t * ('b -> 'a) -> 'a extract_branch type 'a bij = 'a t @@ -55,7 +59,7 @@ let pair a b = Pair(a,b) let triple a b c = Triple (a,b,c) let map ~inject ~extract b = Map (inject, extract, b) -let switch select l = Switch (select, l) +let switch ~inject ~extract = Switch (inject, extract) exception EOF @@ -238,8 +242,8 @@ module SexpEncode(Sink : SINK) = struct | Float, f -> Sink.write_float sink f | List bij', l -> Sink.write_char sink '('; - List.iter - (fun x -> Sink.write_char sink ' '; encode bij' x) + List.iteri + (fun i x -> (if i > 0 then Sink.write_char sink ' '); encode bij' x) l; Sink.write_char sink ')' | Many _, [] -> failwith "Bij.encode: expected non-empty list" @@ -270,15 +274,11 @@ module SexpEncode(Sink : SINK) = struct | Map (inject, _, bij'), x -> let y = inject x in encode bij' y - | Switch (select, l), x -> - let c = select x in - try - let bij' = List.assq c l in - encode bij' x - with Not_found -> - raise (EncodingError "no encoding in switch") + | Switch (inject, _), x -> + let c, BranchTo (bij', y, _) = inject x in + Sink.write_char sink c; + encode bij' y in encode bij x - end module SexpDecode(Source : SOURCE) = struct @@ -347,7 +347,12 @@ module SexpDecode(Source : SOURCE) = struct | Map (_, extract, bij') -> let x = decode bij' in extract x - | Switch (_, choices) -> decode_switch choices + | Switch (_, extract) -> + let c = cur () in + let BranchFrom (bij', convert) = extract c in + junk (); (* remove c *) + let y = decode bij' in + convert y (* translate back *) and decode_open : unit -> unit = fun () -> match cur () with | '(' -> junk () (* done *) | _ -> raise (DecodingError "expected '('") @@ -377,16 +382,6 @@ module SexpDecode(Source : SOURCE) = struct | _ -> let x = decode bij in decode_list bij (x :: l) - and decode_switch : type a. (char * a t) list -> a = fun choices -> - let c = cur () in - junk (); - let bij = - try List.assq c choices - with Not_found -> - try List.assq ' ' choices - with Not_found -> raise (DecodingError "no choice") - in - decode bij in decode bij end diff --git a/bij.mli b/bij.mli index 14a4cda9..d880985d 100644 --- a/bij.mli +++ b/bij.mli @@ -42,7 +42,14 @@ val pair : 'a t -> 'b t -> ('a * 'b) t val triple : 'a t -> 'b t -> 'c t -> ('a * 'b * 'c) t val map : inject:('a -> 'b) -> extract:('b -> 'a) -> 'b t -> 'a t -val switch : ('a -> char) -> (char * 'a t) list -> 'a t + +type _ inject_branch = + | BranchTo : 'b t * 'b * 'a -> 'a inject_branch +type _ extract_branch = + | BranchFrom : 'b t * ('b -> 'a) -> 'a extract_branch + +val switch : inject:('a -> char * 'a inject_branch) -> + extract:(char -> 'a extract_branch) -> 'a t (** discriminates based on the next character. The selection function, with type ['a -> char], is used to select a bijection depending on the value. diff --git a/tests/test_bij.ml b/tests/test_bij.ml index 791e03db..f7c2d61a 100644 --- a/tests/test_bij.ml +++ b/tests/test_bij.ml @@ -26,6 +26,32 @@ let test_intlist n () = let l' = SexpStr.of_string ~bij s in OUnit.assert_equal ~printer:pp_int_list l l' +type term = + | Const of string + | Int of int + | App of term list + +let bij_term = + let rec mk_bij () = + switch + ~inject:(fun t -> match t with + | Const s -> 'c', BranchTo (string_, s, t) + | Int i -> 'i', BranchTo (int_, i, t) + | App l -> 'a', BranchTo (list_ (mk_bij ()), l, t)) + ~extract:(function + | 'c' -> BranchFrom (string_, fun x -> Const x) + | 'i' -> BranchFrom (int_, fun x -> Int x) + | 'a' -> BranchFrom (list_ (mk_bij ()), fun l -> App l) + | _ -> raise (DecodingError "unexpected case switch")) + in mk_bij () + +let test_rec () = + let t = App [Const "foo"; App [Const "bar"; Int 1; Int 2]; Int 3; Const "hello"] in + let s = SexpStr.to_string ~bij:bij_term t in + Printf.printf "to: %s\n" s; + let t' = SexpStr.of_string ~bij:bij_term s in + OUnit.assert_equal t t' + let suite = "test_bij" >::: [ "test_int2" >:: test_int2; @@ -33,4 +59,5 @@ let suite = "test_intlist10" >:: test_intlist 10; "test_intlist100" >:: test_intlist 100; "test_intlist10_000" >:: test_intlist 10_000; + "test_rec" >:: test_rec; ]