mirror of
https://github.com/c-cube/ocaml-containers.git
synced 2025-12-09 12:45:34 -05:00
feat: add code-generator for optimal bitfields; add tests
This commit is contained in:
parent
5593e28431
commit
5ad8914e4c
5 changed files with 271 additions and 0 deletions
138
src/codegen/containers_codegen.ml
Normal file
138
src/codegen/containers_codegen.ml
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
|
||||
(** {1 Code generators} *)
|
||||
|
||||
module Fmt = CCFormat
|
||||
let spf = Printf.sprintf
|
||||
let fpf = Fmt.fprintf
|
||||
|
||||
type code =
|
||||
| Base of { pp: unit Fmt.printer }
|
||||
| Struct of string * code list
|
||||
| Sig of string * code list
|
||||
|
||||
module Code = struct
|
||||
type t = code
|
||||
|
||||
let in_struct m (cs:t list) : t = Struct (m, cs)
|
||||
let in_sig m (cs:t list) : t = Sig (m, cs)
|
||||
|
||||
let rec pp_rec out c =
|
||||
let ppl = Fmt.(list ~sep:(return "@ ") pp_rec) in
|
||||
match c with
|
||||
| Base {pp} -> pp out ()
|
||||
| Struct (m,cs) ->
|
||||
fpf out "@[<hv2>module %s = struct@ %a@;<1 -2>end@]" m ppl cs
|
||||
| Sig (m,cs) ->
|
||||
fpf out "@[<hv2>module %s : sig@ %a@;<1 -2>end@]" m ppl cs
|
||||
|
||||
let pp out c = fpf out "@[<v>%a@]" pp_rec c
|
||||
let to_string c = Fmt.to_string pp c
|
||||
|
||||
let mk_pp pp = Base {pp}
|
||||
let mk_str s = Base {pp=Fmt.const Fmt.string s}
|
||||
end
|
||||
|
||||
module Bitfield = struct
|
||||
type field = {
|
||||
f_name: string;
|
||||
f_offset: int;
|
||||
f_def: field_def;
|
||||
}
|
||||
and field_def =
|
||||
| F_bit
|
||||
| F_int of {width: int}
|
||||
|
||||
type t = {
|
||||
name: string;
|
||||
mutable fields: field list;
|
||||
mutable width: int;
|
||||
emit_failure_if_too_wide: bool;
|
||||
}
|
||||
|
||||
let make ?(emit_failure_if_too_wide=true) ~name () : t =
|
||||
{ name; fields=[]; width=0; emit_failure_if_too_wide; }
|
||||
|
||||
let total_width self = self.width
|
||||
|
||||
let field_bit self f_name =
|
||||
let f_offset = total_width self in
|
||||
let f = {f_name; f_offset; f_def=F_bit} in
|
||||
self.fields <- f :: self.fields;
|
||||
self.width <- 1 + self.width
|
||||
|
||||
let field_int self ~width f_name : unit =
|
||||
let f_offset = total_width self in
|
||||
let f = {f_name; f_offset; f_def=F_int {width}} in
|
||||
self.fields <- f :: self.fields;
|
||||
self.width <- self.width + width
|
||||
|
||||
let empty_name self =
|
||||
if self.name = "t" then "empty" else spf "empty_%s" self.name
|
||||
|
||||
let gen_ml self : code =
|
||||
Code.mk_pp @@ fun out () ->
|
||||
fpf out "@[<v>type %s = int@," self.name;
|
||||
fpf out "@[let %s : %s = 0@]@," (empty_name self) self.name;
|
||||
List.iter
|
||||
(fun f ->
|
||||
let inline = "[@inline]" in (* TODO: option to enable/disable that *)
|
||||
let off = f.f_offset in
|
||||
match f.f_def with
|
||||
| F_bit ->
|
||||
let x_lsr = if off = 0 then "x" else spf "(x lsr %d)" off in
|
||||
fpf out "@[let%s get_%s (x:%s) : bool = (%s land 1) <> 0@]@,"
|
||||
inline f.f_name self.name x_lsr;
|
||||
let mask_shifted = 1 lsl off in
|
||||
fpf out "@[<2>let%s set_%s (v:bool) (x:%s) : %s =@ \
|
||||
if v then x lor %d else x land (lnot %d)@]@,"
|
||||
inline f.f_name self.name self.name mask_shifted mask_shifted;
|
||||
| F_int {width} ->
|
||||
let mask0 = (1 lsl width) - 1 in
|
||||
fpf out "@[let%s get_%s (x:%s) : int = ((x lsr %d) land %d)@]@,"
|
||||
inline f.f_name self.name off mask0;
|
||||
fpf out "@[<2>let%s set_%s (i:int) (x:%s) : %s =@ \
|
||||
assert ((i land %d) == i);@ \
|
||||
((x land (lnot %d)) lor (i lsl %d))@]@,"
|
||||
inline f.f_name self.name self.name
|
||||
mask0 (mask0 lsl off) off;
|
||||
)
|
||||
(List.rev self.fields);
|
||||
(* check width *)
|
||||
if self.emit_failure_if_too_wide then (
|
||||
fpf out "(* check that int size is big enough *)@,\
|
||||
@[let () = assert (Sys.int_size >= %d);;@]" (total_width self);
|
||||
);
|
||||
fpf out "@]"
|
||||
|
||||
let gen_mli self : code =
|
||||
Code.mk_pp @@ fun out () ->
|
||||
fpf out "@[<v>type %s = private int@," self.name;
|
||||
fpf out "@[<v>val %s : %s@," (empty_name self) self.name;
|
||||
List.iter
|
||||
(fun f ->
|
||||
match f.f_def with
|
||||
| F_bit ->
|
||||
fpf out "@[val get_%s : %s -> bool@]@," f.f_name self.name;
|
||||
fpf out "@[val set_%s : bool -> %s -> %s@]@," f.f_name self.name self.name;
|
||||
| F_int {width} ->
|
||||
fpf out "@[val get_%s : %s -> int@]@,"
|
||||
f.f_name self.name;
|
||||
fpf out "@,@[(** %d bits integer *)@]@,\
|
||||
@[val set_%s : int -> %s -> %s@]@,"
|
||||
width f.f_name self.name self.name;
|
||||
)
|
||||
(List.rev self.fields);
|
||||
fpf out "@]"
|
||||
end
|
||||
|
||||
let emit_chan oc cs =
|
||||
let fmt = Fmt.formatter_of_out_channel oc in
|
||||
List.iter (fun c -> Fmt.fprintf fmt "@[%a@]@." Code.pp c) cs;
|
||||
Fmt.fprintf fmt "@?"
|
||||
|
||||
let emit_file file cs =
|
||||
CCIO.with_out file (fun oc -> emit_chan oc cs)
|
||||
|
||||
let emit_string cs : string =
|
||||
Fmt.asprintf "@[<v>%a@]" (Fmt.list ~sep:(Fmt.return "@ ") Code.pp) cs
|
||||
|
||||
46
src/codegen/containers_codegen.mli
Normal file
46
src/codegen/containers_codegen.mli
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
|
||||
(** {1 Code generators} *)
|
||||
|
||||
module Fmt = CCFormat
|
||||
|
||||
type code
|
||||
|
||||
(** {2 Representation of OCaml code} *)
|
||||
module Code : sig
|
||||
type t = code
|
||||
|
||||
val pp : t Fmt.printer
|
||||
val to_string : t -> string
|
||||
|
||||
val mk_pp : unit Fmt.printer -> t
|
||||
val mk_str : string -> t
|
||||
val in_struct : string -> t list -> t
|
||||
val in_sig : string -> t list -> t
|
||||
end
|
||||
|
||||
(** {2 Generate efficient bitfields that fit in an integer} *)
|
||||
module Bitfield : sig
|
||||
type t
|
||||
|
||||
val make :
|
||||
?emit_failure_if_too_wide:bool ->
|
||||
name:string ->
|
||||
unit -> t
|
||||
(** Make a new bitfield with the given name.
|
||||
@param name the name of the generated type
|
||||
@param emit_failure_if_too_wide if true, generated code includes a runtime
|
||||
assertion that {!Sys.int_size} is wide enough to support this type *)
|
||||
|
||||
val field_bit : t -> string -> unit
|
||||
val field_int : t -> width:int -> string -> unit
|
||||
|
||||
val total_width : t -> int
|
||||
|
||||
val gen_mli : t -> code
|
||||
val gen_ml : t -> code
|
||||
end
|
||||
|
||||
val emit_file : string -> code list -> unit
|
||||
val emit_chan : out_channel -> code list -> unit
|
||||
val emit_string : code list -> string
|
||||
|
||||
7
src/codegen/dune
Normal file
7
src/codegen/dune
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
|
||||
(library
|
||||
(name containers_codegen)
|
||||
(public_name containers.codegen)
|
||||
(synopsis "code generators for Containers")
|
||||
(libraries containers)
|
||||
(flags :standard -warn-error -a+8))
|
||||
24
src/codegen/tests/dune
Normal file
24
src/codegen/tests/dune
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
|
||||
; emit tests
|
||||
|
||||
(executable
|
||||
(name emit_tests)
|
||||
(modules emit_tests)
|
||||
(flags :standard -warn-error -a+8)
|
||||
(libraries containers containers.codegen))
|
||||
|
||||
(rule
|
||||
(targets test_bitfield.ml test_bitfield.mli)
|
||||
(action (run ./emit_tests.exe)))
|
||||
|
||||
; run tests
|
||||
|
||||
(executables
|
||||
(names test_bitfield)
|
||||
(modules test_bitfield)
|
||||
(flags :standard -warn-error -a+8)
|
||||
(libraries containers))
|
||||
|
||||
(alias
|
||||
(name runtest)
|
||||
(action (run ./test_bitfield.exe)))
|
||||
56
src/codegen/tests/emit_tests.ml
Normal file
56
src/codegen/tests/emit_tests.ml
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
module CG = Containers_codegen
|
||||
module Vec = CCVector
|
||||
|
||||
let spf = Printf.sprintf
|
||||
|
||||
let emit_bitfields () =
|
||||
let module B = CG.Bitfield in
|
||||
let ml = Vec.create() in
|
||||
let mli = Vec.create() in
|
||||
begin
|
||||
let b = B.make ~name:"t" () in
|
||||
B.field_bit b "x";
|
||||
B.field_bit b "y";
|
||||
B.field_bit b "z";
|
||||
B.field_int b ~width:5 "foo";
|
||||
|
||||
Vec.push ml (CG.Code.in_struct "T1" [B.gen_ml b]);
|
||||
Vec.push mli (CG.Code.in_sig "T1" [B.gen_mli b]);
|
||||
(* check width *)
|
||||
Vec.push ml
|
||||
(CG.Code.mk_str (spf "let() = assert (%d = 8);;" (B.total_width b)));
|
||||
()
|
||||
end;
|
||||
|
||||
Vec.push ml @@ CG.Code.mk_str {|
|
||||
let n_fails = ref 0;;
|
||||
at_exit (fun () -> if !n_fails > 0 then exit 1);;
|
||||
let assert_true line s =
|
||||
if not s then ( incr n_fails; Printf.eprintf "test failure at %d\n%!" line);;
|
||||
|
||||
|};
|
||||
|
||||
let test1 = {|
|
||||
assert_true __LINE__ T1.(get_y (empty |> set_x true |> set_y true |> set_foo 10));;
|
||||
assert_true __LINE__ T1.(get_x (empty |> set_x true |> set_y true |> set_foo 10));;
|
||||
assert_true __LINE__ T1.(get_y (empty |> set_x true |> set_z true
|
||||
|> set_y false |> set_x false |> set_y true));;
|
||||
assert_true __LINE__ T1.(get_z (empty |> set_z true));;
|
||||
assert_true __LINE__ T1.(not @@ get_x (empty |> set_z true));;
|
||||
assert_true __LINE__ T1.(not @@ get_y (empty |> set_z true |> set_x true));;
|
||||
assert_true __LINE__ T1.(not @@ get_y (empty |> set_z true |> set_foo 18));;
|
||||
(* check width of foo *)
|
||||
assert_true __LINE__ T1.(try ignore (empty |> set_foo (1 lsl 6)); false with _ -> true);;
|
||||
assert_true __LINE__ T1.(12 = get_foo (empty |> set_x true |> set_foo 12 |> set_x false));;
|
||||
assert_true __LINE__ T1.(24 = get_foo (empty |> set_y true |> set_foo 24 |> set_z true));;
|
||||
|} |> CG.Code.mk_str in
|
||||
Vec.push ml test1;
|
||||
|
||||
CG.emit_file "test_bitfield.ml" (Vec.to_list ml);
|
||||
CG.emit_file "test_bitfield.mli" (Vec.to_list mli);
|
||||
()
|
||||
|
||||
let () =
|
||||
emit_bitfields();
|
||||
()
|
||||
|
||||
Loading…
Add table
Reference in a new issue