mirror of
https://github.com/c-cube/ocaml-containers.git
synced 2025-12-09 12:45:34 -05:00
feat: add code-generator for optimal bitfields; add tests
This commit is contained in:
parent
5593e28431
commit
5ad8914e4c
5 changed files with 271 additions and 0 deletions
138
src/codegen/containers_codegen.ml
Normal file
138
src/codegen/containers_codegen.ml
Normal file
|
|
@ -0,0 +1,138 @@
|
||||||
|
|
||||||
|
(** {1 Code generators} *)
|
||||||
|
|
||||||
|
module Fmt = CCFormat
|
||||||
|
let spf = Printf.sprintf
|
||||||
|
let fpf = Fmt.fprintf
|
||||||
|
|
||||||
|
type code =
|
||||||
|
| Base of { pp: unit Fmt.printer }
|
||||||
|
| Struct of string * code list
|
||||||
|
| Sig of string * code list
|
||||||
|
|
||||||
|
module Code = struct
|
||||||
|
type t = code
|
||||||
|
|
||||||
|
let in_struct m (cs:t list) : t = Struct (m, cs)
|
||||||
|
let in_sig m (cs:t list) : t = Sig (m, cs)
|
||||||
|
|
||||||
|
let rec pp_rec out c =
|
||||||
|
let ppl = Fmt.(list ~sep:(return "@ ") pp_rec) in
|
||||||
|
match c with
|
||||||
|
| Base {pp} -> pp out ()
|
||||||
|
| Struct (m,cs) ->
|
||||||
|
fpf out "@[<hv2>module %s = struct@ %a@;<1 -2>end@]" m ppl cs
|
||||||
|
| Sig (m,cs) ->
|
||||||
|
fpf out "@[<hv2>module %s : sig@ %a@;<1 -2>end@]" m ppl cs
|
||||||
|
|
||||||
|
let pp out c = fpf out "@[<v>%a@]" pp_rec c
|
||||||
|
let to_string c = Fmt.to_string pp c
|
||||||
|
|
||||||
|
let mk_pp pp = Base {pp}
|
||||||
|
let mk_str s = Base {pp=Fmt.const Fmt.string s}
|
||||||
|
end
|
||||||
|
|
||||||
|
module Bitfield = struct
|
||||||
|
type field = {
|
||||||
|
f_name: string;
|
||||||
|
f_offset: int;
|
||||||
|
f_def: field_def;
|
||||||
|
}
|
||||||
|
and field_def =
|
||||||
|
| F_bit
|
||||||
|
| F_int of {width: int}
|
||||||
|
|
||||||
|
type t = {
|
||||||
|
name: string;
|
||||||
|
mutable fields: field list;
|
||||||
|
mutable width: int;
|
||||||
|
emit_failure_if_too_wide: bool;
|
||||||
|
}
|
||||||
|
|
||||||
|
let make ?(emit_failure_if_too_wide=true) ~name () : t =
|
||||||
|
{ name; fields=[]; width=0; emit_failure_if_too_wide; }
|
||||||
|
|
||||||
|
let total_width self = self.width
|
||||||
|
|
||||||
|
let field_bit self f_name =
|
||||||
|
let f_offset = total_width self in
|
||||||
|
let f = {f_name; f_offset; f_def=F_bit} in
|
||||||
|
self.fields <- f :: self.fields;
|
||||||
|
self.width <- 1 + self.width
|
||||||
|
|
||||||
|
let field_int self ~width f_name : unit =
|
||||||
|
let f_offset = total_width self in
|
||||||
|
let f = {f_name; f_offset; f_def=F_int {width}} in
|
||||||
|
self.fields <- f :: self.fields;
|
||||||
|
self.width <- self.width + width
|
||||||
|
|
||||||
|
let empty_name self =
|
||||||
|
if self.name = "t" then "empty" else spf "empty_%s" self.name
|
||||||
|
|
||||||
|
let gen_ml self : code =
|
||||||
|
Code.mk_pp @@ fun out () ->
|
||||||
|
fpf out "@[<v>type %s = int@," self.name;
|
||||||
|
fpf out "@[let %s : %s = 0@]@," (empty_name self) self.name;
|
||||||
|
List.iter
|
||||||
|
(fun f ->
|
||||||
|
let inline = "[@inline]" in (* TODO: option to enable/disable that *)
|
||||||
|
let off = f.f_offset in
|
||||||
|
match f.f_def with
|
||||||
|
| F_bit ->
|
||||||
|
let x_lsr = if off = 0 then "x" else spf "(x lsr %d)" off in
|
||||||
|
fpf out "@[let%s get_%s (x:%s) : bool = (%s land 1) <> 0@]@,"
|
||||||
|
inline f.f_name self.name x_lsr;
|
||||||
|
let mask_shifted = 1 lsl off in
|
||||||
|
fpf out "@[<2>let%s set_%s (v:bool) (x:%s) : %s =@ \
|
||||||
|
if v then x lor %d else x land (lnot %d)@]@,"
|
||||||
|
inline f.f_name self.name self.name mask_shifted mask_shifted;
|
||||||
|
| F_int {width} ->
|
||||||
|
let mask0 = (1 lsl width) - 1 in
|
||||||
|
fpf out "@[let%s get_%s (x:%s) : int = ((x lsr %d) land %d)@]@,"
|
||||||
|
inline f.f_name self.name off mask0;
|
||||||
|
fpf out "@[<2>let%s set_%s (i:int) (x:%s) : %s =@ \
|
||||||
|
assert ((i land %d) == i);@ \
|
||||||
|
((x land (lnot %d)) lor (i lsl %d))@]@,"
|
||||||
|
inline f.f_name self.name self.name
|
||||||
|
mask0 (mask0 lsl off) off;
|
||||||
|
)
|
||||||
|
(List.rev self.fields);
|
||||||
|
(* check width *)
|
||||||
|
if self.emit_failure_if_too_wide then (
|
||||||
|
fpf out "(* check that int size is big enough *)@,\
|
||||||
|
@[let () = assert (Sys.int_size >= %d);;@]" (total_width self);
|
||||||
|
);
|
||||||
|
fpf out "@]"
|
||||||
|
|
||||||
|
let gen_mli self : code =
|
||||||
|
Code.mk_pp @@ fun out () ->
|
||||||
|
fpf out "@[<v>type %s = private int@," self.name;
|
||||||
|
fpf out "@[<v>val %s : %s@," (empty_name self) self.name;
|
||||||
|
List.iter
|
||||||
|
(fun f ->
|
||||||
|
match f.f_def with
|
||||||
|
| F_bit ->
|
||||||
|
fpf out "@[val get_%s : %s -> bool@]@," f.f_name self.name;
|
||||||
|
fpf out "@[val set_%s : bool -> %s -> %s@]@," f.f_name self.name self.name;
|
||||||
|
| F_int {width} ->
|
||||||
|
fpf out "@[val get_%s : %s -> int@]@,"
|
||||||
|
f.f_name self.name;
|
||||||
|
fpf out "@,@[(** %d bits integer *)@]@,\
|
||||||
|
@[val set_%s : int -> %s -> %s@]@,"
|
||||||
|
width f.f_name self.name self.name;
|
||||||
|
)
|
||||||
|
(List.rev self.fields);
|
||||||
|
fpf out "@]"
|
||||||
|
end
|
||||||
|
|
||||||
|
let emit_chan oc cs =
|
||||||
|
let fmt = Fmt.formatter_of_out_channel oc in
|
||||||
|
List.iter (fun c -> Fmt.fprintf fmt "@[%a@]@." Code.pp c) cs;
|
||||||
|
Fmt.fprintf fmt "@?"
|
||||||
|
|
||||||
|
let emit_file file cs =
|
||||||
|
CCIO.with_out file (fun oc -> emit_chan oc cs)
|
||||||
|
|
||||||
|
let emit_string cs : string =
|
||||||
|
Fmt.asprintf "@[<v>%a@]" (Fmt.list ~sep:(Fmt.return "@ ") Code.pp) cs
|
||||||
|
|
||||||
46
src/codegen/containers_codegen.mli
Normal file
46
src/codegen/containers_codegen.mli
Normal file
|
|
@ -0,0 +1,46 @@
|
||||||
|
|
||||||
|
(** {1 Code generators} *)
|
||||||
|
|
||||||
|
module Fmt = CCFormat
|
||||||
|
|
||||||
|
type code
|
||||||
|
|
||||||
|
(** {2 Representation of OCaml code} *)
|
||||||
|
module Code : sig
|
||||||
|
type t = code
|
||||||
|
|
||||||
|
val pp : t Fmt.printer
|
||||||
|
val to_string : t -> string
|
||||||
|
|
||||||
|
val mk_pp : unit Fmt.printer -> t
|
||||||
|
val mk_str : string -> t
|
||||||
|
val in_struct : string -> t list -> t
|
||||||
|
val in_sig : string -> t list -> t
|
||||||
|
end
|
||||||
|
|
||||||
|
(** {2 Generate efficient bitfields that fit in an integer} *)
|
||||||
|
module Bitfield : sig
|
||||||
|
type t
|
||||||
|
|
||||||
|
val make :
|
||||||
|
?emit_failure_if_too_wide:bool ->
|
||||||
|
name:string ->
|
||||||
|
unit -> t
|
||||||
|
(** Make a new bitfield with the given name.
|
||||||
|
@param name the name of the generated type
|
||||||
|
@param emit_failure_if_too_wide if true, generated code includes a runtime
|
||||||
|
assertion that {!Sys.int_size} is wide enough to support this type *)
|
||||||
|
|
||||||
|
val field_bit : t -> string -> unit
|
||||||
|
val field_int : t -> width:int -> string -> unit
|
||||||
|
|
||||||
|
val total_width : t -> int
|
||||||
|
|
||||||
|
val gen_mli : t -> code
|
||||||
|
val gen_ml : t -> code
|
||||||
|
end
|
||||||
|
|
||||||
|
val emit_file : string -> code list -> unit
|
||||||
|
val emit_chan : out_channel -> code list -> unit
|
||||||
|
val emit_string : code list -> string
|
||||||
|
|
||||||
7
src/codegen/dune
Normal file
7
src/codegen/dune
Normal file
|
|
@ -0,0 +1,7 @@
|
||||||
|
|
||||||
|
(library
|
||||||
|
(name containers_codegen)
|
||||||
|
(public_name containers.codegen)
|
||||||
|
(synopsis "code generators for Containers")
|
||||||
|
(libraries containers)
|
||||||
|
(flags :standard -warn-error -a+8))
|
||||||
24
src/codegen/tests/dune
Normal file
24
src/codegen/tests/dune
Normal file
|
|
@ -0,0 +1,24 @@
|
||||||
|
|
||||||
|
; emit tests
|
||||||
|
|
||||||
|
(executable
|
||||||
|
(name emit_tests)
|
||||||
|
(modules emit_tests)
|
||||||
|
(flags :standard -warn-error -a+8)
|
||||||
|
(libraries containers containers.codegen))
|
||||||
|
|
||||||
|
(rule
|
||||||
|
(targets test_bitfield.ml test_bitfield.mli)
|
||||||
|
(action (run ./emit_tests.exe)))
|
||||||
|
|
||||||
|
; run tests
|
||||||
|
|
||||||
|
(executables
|
||||||
|
(names test_bitfield)
|
||||||
|
(modules test_bitfield)
|
||||||
|
(flags :standard -warn-error -a+8)
|
||||||
|
(libraries containers))
|
||||||
|
|
||||||
|
(alias
|
||||||
|
(name runtest)
|
||||||
|
(action (run ./test_bitfield.exe)))
|
||||||
56
src/codegen/tests/emit_tests.ml
Normal file
56
src/codegen/tests/emit_tests.ml
Normal file
|
|
@ -0,0 +1,56 @@
|
||||||
|
module CG = Containers_codegen
|
||||||
|
module Vec = CCVector
|
||||||
|
|
||||||
|
let spf = Printf.sprintf
|
||||||
|
|
||||||
|
let emit_bitfields () =
|
||||||
|
let module B = CG.Bitfield in
|
||||||
|
let ml = Vec.create() in
|
||||||
|
let mli = Vec.create() in
|
||||||
|
begin
|
||||||
|
let b = B.make ~name:"t" () in
|
||||||
|
B.field_bit b "x";
|
||||||
|
B.field_bit b "y";
|
||||||
|
B.field_bit b "z";
|
||||||
|
B.field_int b ~width:5 "foo";
|
||||||
|
|
||||||
|
Vec.push ml (CG.Code.in_struct "T1" [B.gen_ml b]);
|
||||||
|
Vec.push mli (CG.Code.in_sig "T1" [B.gen_mli b]);
|
||||||
|
(* check width *)
|
||||||
|
Vec.push ml
|
||||||
|
(CG.Code.mk_str (spf "let() = assert (%d = 8);;" (B.total_width b)));
|
||||||
|
()
|
||||||
|
end;
|
||||||
|
|
||||||
|
Vec.push ml @@ CG.Code.mk_str {|
|
||||||
|
let n_fails = ref 0;;
|
||||||
|
at_exit (fun () -> if !n_fails > 0 then exit 1);;
|
||||||
|
let assert_true line s =
|
||||||
|
if not s then ( incr n_fails; Printf.eprintf "test failure at %d\n%!" line);;
|
||||||
|
|
||||||
|
|};
|
||||||
|
|
||||||
|
let test1 = {|
|
||||||
|
assert_true __LINE__ T1.(get_y (empty |> set_x true |> set_y true |> set_foo 10));;
|
||||||
|
assert_true __LINE__ T1.(get_x (empty |> set_x true |> set_y true |> set_foo 10));;
|
||||||
|
assert_true __LINE__ T1.(get_y (empty |> set_x true |> set_z true
|
||||||
|
|> set_y false |> set_x false |> set_y true));;
|
||||||
|
assert_true __LINE__ T1.(get_z (empty |> set_z true));;
|
||||||
|
assert_true __LINE__ T1.(not @@ get_x (empty |> set_z true));;
|
||||||
|
assert_true __LINE__ T1.(not @@ get_y (empty |> set_z true |> set_x true));;
|
||||||
|
assert_true __LINE__ T1.(not @@ get_y (empty |> set_z true |> set_foo 18));;
|
||||||
|
(* check width of foo *)
|
||||||
|
assert_true __LINE__ T1.(try ignore (empty |> set_foo (1 lsl 6)); false with _ -> true);;
|
||||||
|
assert_true __LINE__ T1.(12 = get_foo (empty |> set_x true |> set_foo 12 |> set_x false));;
|
||||||
|
assert_true __LINE__ T1.(24 = get_foo (empty |> set_y true |> set_foo 24 |> set_z true));;
|
||||||
|
|} |> CG.Code.mk_str in
|
||||||
|
Vec.push ml test1;
|
||||||
|
|
||||||
|
CG.emit_file "test_bitfield.ml" (Vec.to_list ml);
|
||||||
|
CG.emit_file "test_bitfield.mli" (Vec.to_list mli);
|
||||||
|
()
|
||||||
|
|
||||||
|
let () =
|
||||||
|
emit_bitfields();
|
||||||
|
()
|
||||||
|
|
||||||
Loading…
Add table
Reference in a new issue