diff --git a/src/codegen/containers_codegen.ml b/src/codegen/containers_codegen.ml new file mode 100644 index 00000000..cc390ff3 --- /dev/null +++ b/src/codegen/containers_codegen.ml @@ -0,0 +1,138 @@ + +(** {1 Code generators} *) + +module Fmt = CCFormat +let spf = Printf.sprintf +let fpf = Fmt.fprintf + +type code = + | Base of { pp: unit Fmt.printer } + | Struct of string * code list + | Sig of string * code list + +module Code = struct + type t = code + + let in_struct m (cs:t list) : t = Struct (m, cs) + let in_sig m (cs:t list) : t = Sig (m, cs) + + let rec pp_rec out c = + let ppl = Fmt.(list ~sep:(return "@ ") pp_rec) in + match c with + | Base {pp} -> pp out () + | Struct (m,cs) -> + fpf out "@[module %s = struct@ %a@;<1 -2>end@]" m ppl cs + | Sig (m,cs) -> + fpf out "@[module %s : sig@ %a@;<1 -2>end@]" m ppl cs + + let pp out c = fpf out "@[%a@]" pp_rec c + let to_string c = Fmt.to_string pp c + + let mk_pp pp = Base {pp} + let mk_str s = Base {pp=Fmt.const Fmt.string s} +end + +module Bitfield = struct + type field = { + f_name: string; + f_offset: int; + f_def: field_def; + } + and field_def = + | F_bit + | F_int of {width: int} + + type t = { + name: string; + mutable fields: field list; + mutable width: int; + emit_failure_if_too_wide: bool; + } + + let make ?(emit_failure_if_too_wide=true) ~name () : t = + { name; fields=[]; width=0; emit_failure_if_too_wide; } + + let total_width self = self.width + + let field_bit self f_name = + let f_offset = total_width self in + let f = {f_name; f_offset; f_def=F_bit} in + self.fields <- f :: self.fields; + self.width <- 1 + self.width + + let field_int self ~width f_name : unit = + let f_offset = total_width self in + let f = {f_name; f_offset; f_def=F_int {width}} in + self.fields <- f :: self.fields; + self.width <- self.width + width + + let empty_name self = + if self.name = "t" then "empty" else spf "empty_%s" self.name + + let gen_ml self : code = + Code.mk_pp @@ fun out () -> + fpf out "@[type %s = int@," self.name; + fpf out "@[let %s : %s = 0@]@," (empty_name self) self.name; + List.iter + (fun f -> + let inline = "[@inline]" in (* TODO: option to enable/disable that *) + let off = f.f_offset in + match f.f_def with + | F_bit -> + let x_lsr = if off = 0 then "x" else spf "(x lsr %d)" off in + fpf out "@[let%s get_%s (x:%s) : bool = (%s land 1) <> 0@]@," + inline f.f_name self.name x_lsr; + let mask_shifted = 1 lsl off in + fpf out "@[<2>let%s set_%s (v:bool) (x:%s) : %s =@ \ + if v then x lor %d else x land (lnot %d)@]@," + inline f.f_name self.name self.name mask_shifted mask_shifted; + | F_int {width} -> + let mask0 = (1 lsl width) - 1 in + fpf out "@[let%s get_%s (x:%s) : int = ((x lsr %d) land %d)@]@," + inline f.f_name self.name off mask0; + fpf out "@[<2>let%s set_%s (i:int) (x:%s) : %s =@ \ + assert ((i land %d) == i);@ \ + ((x land (lnot %d)) lor (i lsl %d))@]@," + inline f.f_name self.name self.name + mask0 (mask0 lsl off) off; + ) + (List.rev self.fields); + (* check width *) + if self.emit_failure_if_too_wide then ( + fpf out "(* check that int size is big enough *)@,\ + @[let () = assert (Sys.int_size >= %d);;@]" (total_width self); + ); + fpf out "@]" + + let gen_mli self : code = + Code.mk_pp @@ fun out () -> + fpf out "@[type %s = private int@," self.name; + fpf out "@[val %s : %s@," (empty_name self) self.name; + List.iter + (fun f -> + match f.f_def with + | F_bit -> + fpf out "@[val get_%s : %s -> bool@]@," f.f_name self.name; + fpf out "@[val set_%s : bool -> %s -> %s@]@," f.f_name self.name self.name; + | F_int {width} -> + fpf out "@[val get_%s : %s -> int@]@," + f.f_name self.name; + fpf out "@,@[(** %d bits integer *)@]@,\ + @[val set_%s : int -> %s -> %s@]@," + width f.f_name self.name self.name; + ) + (List.rev self.fields); + fpf out "@]" +end + +let emit_chan oc cs = + let fmt = Fmt.formatter_of_out_channel oc in + List.iter (fun c -> Fmt.fprintf fmt "@[%a@]@." Code.pp c) cs; + Fmt.fprintf fmt "@?" + +let emit_file file cs = + CCIO.with_out file (fun oc -> emit_chan oc cs) + +let emit_string cs : string = + Fmt.asprintf "@[%a@]" (Fmt.list ~sep:(Fmt.return "@ ") Code.pp) cs + diff --git a/src/codegen/containers_codegen.mli b/src/codegen/containers_codegen.mli new file mode 100644 index 00000000..a5680c50 --- /dev/null +++ b/src/codegen/containers_codegen.mli @@ -0,0 +1,46 @@ + +(** {1 Code generators} *) + +module Fmt = CCFormat + +type code + +(** {2 Representation of OCaml code} *) +module Code : sig + type t = code + + val pp : t Fmt.printer + val to_string : t -> string + + val mk_pp : unit Fmt.printer -> t + val mk_str : string -> t + val in_struct : string -> t list -> t + val in_sig : string -> t list -> t +end + +(** {2 Generate efficient bitfields that fit in an integer} *) +module Bitfield : sig + type t + + val make : + ?emit_failure_if_too_wide:bool -> + name:string -> + unit -> t + (** Make a new bitfield with the given name. + @param name the name of the generated type + @param emit_failure_if_too_wide if true, generated code includes a runtime + assertion that {!Sys.int_size} is wide enough to support this type *) + + val field_bit : t -> string -> unit + val field_int : t -> width:int -> string -> unit + + val total_width : t -> int + + val gen_mli : t -> code + val gen_ml : t -> code +end + +val emit_file : string -> code list -> unit +val emit_chan : out_channel -> code list -> unit +val emit_string : code list -> string + diff --git a/src/codegen/dune b/src/codegen/dune new file mode 100644 index 00000000..61068fc3 --- /dev/null +++ b/src/codegen/dune @@ -0,0 +1,7 @@ + +(library + (name containers_codegen) + (public_name containers.codegen) + (synopsis "code generators for Containers") + (libraries containers) + (flags :standard -warn-error -a+8)) diff --git a/src/codegen/tests/dune b/src/codegen/tests/dune new file mode 100644 index 00000000..320af55e --- /dev/null +++ b/src/codegen/tests/dune @@ -0,0 +1,24 @@ + +; emit tests + +(executable + (name emit_tests) + (modules emit_tests) + (flags :standard -warn-error -a+8) + (libraries containers containers.codegen)) + +(rule + (targets test_bitfield.ml test_bitfield.mli) + (action (run ./emit_tests.exe))) + +; run tests + +(executables + (names test_bitfield) + (modules test_bitfield) + (flags :standard -warn-error -a+8) + (libraries containers)) + +(alias + (name runtest) + (action (run ./test_bitfield.exe))) diff --git a/src/codegen/tests/emit_tests.ml b/src/codegen/tests/emit_tests.ml new file mode 100644 index 00000000..038d67ca --- /dev/null +++ b/src/codegen/tests/emit_tests.ml @@ -0,0 +1,56 @@ +module CG = Containers_codegen +module Vec = CCVector + +let spf = Printf.sprintf + +let emit_bitfields () = + let module B = CG.Bitfield in + let ml = Vec.create() in + let mli = Vec.create() in + begin + let b = B.make ~name:"t" () in + B.field_bit b "x"; + B.field_bit b "y"; + B.field_bit b "z"; + B.field_int b ~width:5 "foo"; + + Vec.push ml (CG.Code.in_struct "T1" [B.gen_ml b]); + Vec.push mli (CG.Code.in_sig "T1" [B.gen_mli b]); + (* check width *) + Vec.push ml + (CG.Code.mk_str (spf "let() = assert (%d = 8);;" (B.total_width b))); + () + end; + + Vec.push ml @@ CG.Code.mk_str {| + let n_fails = ref 0;; + at_exit (fun () -> if !n_fails > 0 then exit 1);; + let assert_true line s = + if not s then ( incr n_fails; Printf.eprintf "test failure at %d\n%!" line);; + + |}; + + let test1 = {| + assert_true __LINE__ T1.(get_y (empty |> set_x true |> set_y true |> set_foo 10));; + assert_true __LINE__ T1.(get_x (empty |> set_x true |> set_y true |> set_foo 10));; + assert_true __LINE__ T1.(get_y (empty |> set_x true |> set_z true + |> set_y false |> set_x false |> set_y true));; + assert_true __LINE__ T1.(get_z (empty |> set_z true));; + assert_true __LINE__ T1.(not @@ get_x (empty |> set_z true));; + assert_true __LINE__ T1.(not @@ get_y (empty |> set_z true |> set_x true));; + assert_true __LINE__ T1.(not @@ get_y (empty |> set_z true |> set_foo 18));; + (* check width of foo *) + assert_true __LINE__ T1.(try ignore (empty |> set_foo (1 lsl 6)); false with _ -> true);; + assert_true __LINE__ T1.(12 = get_foo (empty |> set_x true |> set_foo 12 |> set_x false));; + assert_true __LINE__ T1.(24 = get_foo (empty |> set_y true |> set_foo 24 |> set_z true));; + |} |> CG.Code.mk_str in + Vec.push ml test1; + + CG.emit_file "test_bitfield.ml" (Vec.to_list ml); + CG.emit_file "test_bitfield.mli" (Vec.to_list mli); + () + +let () = + emit_bitfields(); + () +