CCGraph: more functions, better interface for traversals

This commit is contained in:
Simon Cruanes 2015-06-10 14:21:23 +02:00
parent d7b15ca81e
commit 20d72e5755
2 changed files with 378 additions and 101 deletions

View file

@ -35,6 +35,7 @@ module Seq = struct
let return x k = k x let return x k = k x
let (>>=) a f k = a (fun x -> f x k) let (>>=) a f k = a (fun x -> f x k)
let map f a k = a (fun x -> k (f x)) let map f a k = a (fun x -> k (f x))
let filter_map f a k = a (fun x -> match f x with None -> () | Some y -> k y)
let iter f a = a f let iter f a = a f
let fold f acc a = let fold f acc a =
let acc = ref acc in let acc = ref acc in
@ -51,10 +52,12 @@ type ('v, 'e) t = {
dest: 'e -> 'v; dest: 'e -> 'v;
} }
type ('v, 'e) graph = ('v, 'e) t
(** Mutable bitset for values of type ['v] *) (** Mutable bitset for values of type ['v] *)
type 'v tag_set = { type 'v tag_set = {
get_tag: 'v -> bool; get_tag: 'v -> bool;
set_tag: 'v -> unit; (** Set tag to [true] for the given element *) set_tag: 'v -> unit; (** Set tag for the given element *)
} }
(** Mutable table with keys ['k] and values ['a] *) (** Mutable table with keys ['k] and values ['a] *)
@ -81,7 +84,19 @@ let mk_table (type k) ?(eq=(=)) ?(hash=Hashtbl.hash) size =
; size=(fun () -> H.length tbl) ; size=(fun () -> H.length tbl)
} }
(** {2 Traversals} *) let mk_map (type k) ?(cmp=Pervasives.compare) () =
let module M = Map.Make(struct
type t = k
let compare = cmp
end) in
let tbl = ref M.empty in
{ mem=(fun k -> M.mem k !tbl)
; find=(fun k -> M.find k !tbl)
; add=(fun k v -> tbl := M.add k v !tbl)
; size=(fun () -> M.cardinal !tbl)
}
(** {2 Bags} *)
type 'a bag = { type 'a bag = {
push: 'a -> unit; push: 'a -> unit;
@ -140,24 +155,10 @@ let mk_heap ~leq =
) )
} }
let traverse ?tbl:(mk_tbl=mk_table ?eq:None ?hash:None) ~bag:mk_bag ~graph seq = (** {2 Traversals} *)
fun k ->
let bag = mk_bag() in
Seq.iter bag.push seq;
let tbl = mk_tbl 128 in
let bag = mk_bag () in
while not (bag.is_empty ()) do
let x = bag.pop () in
if not (tbl.mem x) then (
k x;
tbl.add x ();
Seq.iter
(fun e -> bag.push (graph.dest e))
(graph.children x)
)
done
let traverse_tag ~tags ~bag ~graph seq = module Traverse = struct
let generic_tag ~tags ~bag ~graph seq =
let first = ref true in let first = ref true in
fun k -> fun k ->
(* ensure linearity *) (* ensure linearity *)
@ -174,26 +175,23 @@ let traverse_tag ~tags ~bag ~graph seq =
) )
done done
let generic ?(tbl=mk_table 128) ~bag ~graph seq =
let tags = {
get_tag=tbl.mem;
set_tag=(fun v -> tbl.add v ());
} in
generic_tag ~tags ~bag ~graph seq
let bfs ?tbl ~graph seq = let bfs ?tbl ~graph seq =
traverse ?tbl ~bag:mk_queue ~graph seq generic ?tbl ~bag:(mk_queue ()) ~graph seq
let bfs_tag ~tags ~graph seq = let bfs_tag ~tags ~graph seq =
traverse_tag ~tags ~bag:(mk_queue()) ~graph seq generic_tag ~tags ~bag:(mk_queue()) ~graph seq
let dfs ?tbl ~graph seq = let dijkstra_tag ?(dist=fun _ -> 1) ~tags ~graph seq =
traverse ?tbl ~bag:mk_stack ~graph seq let tags' = {
get_tag=(fun (v,_) -> tags.get_tag v);
let dfs_tag ~tags ~graph seq = set_tag=(fun (v,_) -> tags.set_tag v);
traverse_tag ~tags ~bag:(mk_stack()) ~graph seq
let dijkstra ?(tbl=mk_table ?eq:None ?hash:None) ?(dist=fun _ -> 1) ~graph seq =
(* a table [('v * int) -> 'a] built from a ['v -> 'a] table *)
let mk_tbl' size =
let vertex_tbl = tbl size in
{ mem=(fun (v, _) -> vertex_tbl.mem v)
; find=(fun (v, _) -> vertex_tbl.find v)
; add=(fun (v, _) -> vertex_tbl.add v)
; size=vertex_tbl.size
} }
and seq' = Seq.map (fun v -> v, 0) seq and seq' = Seq.map (fun v -> v, 0) seq
and graph' = { and graph' = {
@ -201,12 +199,199 @@ let dijkstra ?(tbl=mk_table ?eq:None ?hash:None) ?(dist=fun _ -> 1) ~graph seq =
origin=(fun (e, d) -> graph.origin e, d); origin=(fun (e, d) -> graph.origin e, d);
dest=(fun (e, d) -> graph.dest e, d + dist e); dest=(fun (e, d) -> graph.dest e, d + dist e);
} in } in
let mk_bag () = mk_heap ~leq:(fun (_, d1) (_, d2) -> d1 <= d2) in let bag = mk_heap ~leq:(fun (_, d1) (_, d2) -> d1 <= d2) in
traverse ~tbl:mk_tbl' ~bag:mk_bag ~graph:graph' seq' generic_tag ~tags:tags' ~bag ~graph:graph' seq'
let dijkstra_tag ?(dist=fun _ -> 1) ~tags ~graph seq = assert false (* TODO *)
let dijkstra ?(tbl=mk_table 128) ?dist ~graph seq =
let tags = {
get_tag=tbl.mem;
set_tag=(fun v -> tbl.add v ());
} in
dijkstra_tag ~tags ?dist ~graph seq
let dfs ?tbl ~graph seq =
generic ?tbl ~bag:(mk_stack ()) ~graph seq
let dfs_tag ~tags ~graph seq =
generic_tag ~tags ~bag:(mk_stack()) ~graph seq
module Event = struct
type edge_kind = [`Forward | `Back | `Cross ]
type 'e path = 'e list
(** A traversal is a sequence of such events *)
type ('v,'e) t =
[ `Enter of 'v * int * 'e path (* unique index in traversal, path from start *)
| `Exit of 'v
| `Edge of 'e * edge_kind
]
let get_vertex = function
| `Enter (v, _, _) -> Some (v, `Enter)
| `Exit v -> Some (v, `Exit)
| `Edge _ -> None
let get_enter = function
| `Enter (v, _, _) -> Some v
| `Exit _
| `Edge _ -> None
let get_exit = function
| `Exit v -> Some v
| `Enter _
| `Edge _ -> None
let get_edge = function
| `Edge (e, _) -> Some e
| `Enter _
| `Exit _ -> None
let get_edge_kind = function
| `Edge (e, k) -> Some (e, k)
| `Enter _
| `Exit _ -> None
(* is [v] the origin of some edge in [path]? *)
let rec list_mem_ ~eq ~graph v path = match path with
| [] -> false
| e :: path' ->
eq v (graph.origin e) || list_mem_ ~eq ~graph v path'
let dfs_tag ?(eq=(=)) ~tags ~graph seq =
let first = ref true in
fun k ->
if !first then first := false else raise Sequence_once;
let bag = mk_stack() in
let n = ref 0 in
Seq.iter
(fun v ->
(* start DFS from this vertex *)
bag.push (`Enter (v, []));
while not (bag.is_empty ()) do
match bag.pop () with
| `Enter (x, path) ->
if not (tags.get_tag x) then (
let num = !n in
incr n;
tags.set_tag x;
k (`Enter (x, num, path));
bag.push (`Exit x);
Seq.iter
(fun e -> bag.push (`Edge (e, e :: path)))
(graph.children x);
)
| `Exit x -> k (`Exit x)
| `Edge (e, path) ->
let v = graph.dest e in
let edge_kind =
if tags.get_tag v
then if list_mem_ ~eq ~graph v path
then `Back
else `Cross
else `Forward
in
k (`Edge (e, edge_kind))
done
) seq
let dfs ?(tbl=mk_table 128) ?eq ~graph seq =
let tags = {
set_tag=(fun v -> tbl.add v ());
get_tag=tbl.mem;
} in
dfs_tag ?eq ~tags ~graph seq
end
end
module Dot = struct
type attribute = [
| `Color of string
| `Shape of string
| `Weight of int
| `Style of string
| `Label of string
| `Other of string * string
] (** Dot attribute *)
let pp_list pp_x out l =
Format.pp_print_string out "[";
List.iteri (fun i x ->
if i > 0 then Format.fprintf out ",@;";
pp_x out x
) l;
Format.pp_print_string out "]"
(** Print an enum of Full.traverse_event *)
let pp_seq
?(tbl=mk_table 128)
?(attrs_v=fun _ -> [])
?(attrs_e=fun _ -> [])
?(name="graph")
~graph out seq =
(* print an attribute *)
let pp_attr out attr = match attr with
| `Color c -> Format.fprintf out "color=%s" c
| `Shape s -> Format.fprintf out "shape=%s" s
| `Weight w -> Format.fprintf out "weight=%d" w
| `Style s -> Format.fprintf out "style=%s" s
| `Label l -> Format.fprintf out "label=\"%s\"" l
| `Other (name, value) -> Format.fprintf out "%s=\"%s\"" name value
(* map from vertices to integers *)
and get_id =
let count = ref 0 in
fun v ->
try tbl.find v
with Not_found ->
let n = !count in
incr count;
tbl.add v n;
n
in
(* the unique name of a vertex *)
let pp_vertex out v = Format.fprintf out "vertex_%d" (get_id v) in
(* print preamble *)
Format.fprintf out "@[<v2>digraph %s {@;" name;
(* traverse *)
let tags = {
get_tag=tbl.mem;
set_tag=(fun v -> ignore (get_id v)); (* allocate new ID *)
} in
let events = Traverse.Event.dfs_tag ~tags ~graph seq in
Seq.iter
(function
| `Enter (v, _n, _path) ->
let attrs = attrs_v v in
Format.fprintf out " @[<h>%a %a;@]@." pp_vertex v (pp_list pp_attr) attrs
| `Exit _ -> ()
| `Edge (e, _) ->
let v1 = graph.origin e in
let v2 = graph.dest e in
let attrs = attrs_e e in
Format.fprintf out " @[<h>%a -> %a %a;@]@."
pp_vertex v1 pp_vertex v2
(pp_list pp_attr)
attrs
) events;
(* close *)
Format.fprintf out "}@]@;@?";
()
let pp ?tbl ?attrs_v ?attrs_e ?name ~graph fmt v =
pp_seq ?tbl ?attrs_v ?attrs_e ?name ~graph fmt (Seq.return v)
let with_out filename f =
let oc = open_out filename in
try
let fmt = Format.formatter_of_out_channel oc in
let x = f fmt in
Format.pp_print_flush fmt ();
close_out oc;
x
with e ->
close_out oc;
raise e
end

View file

@ -40,6 +40,7 @@ module Seq : sig
val return : 'a -> 'a sequence val return : 'a -> 'a sequence
val (>>=) : 'a t -> ('a -> 'b t) -> 'b t val (>>=) : 'a t -> ('a -> 'b t) -> 'b t
val map : ('a -> 'b) -> 'a t -> 'b t val map : ('a -> 'b) -> 'a t -> 'b t
val filter_map : ('a -> 'b option) -> 'a t -> 'b t
val iter : ('a -> unit) -> 'a t -> unit val iter : ('a -> unit) -> 'a t -> unit
val fold: ('b -> 'a -> 'b) -> 'b -> 'a t -> 'b val fold: ('b -> 'a -> 'b) -> 'b -> 'a t -> 'b
end end
@ -53,10 +54,12 @@ type ('v, 'e) t = {
dest: 'e -> 'v; dest: 'e -> 'v;
} }
(** Mutable bitset for values of type ['v] *) type ('v, 'e) graph = ('v, 'e) t
(** Mutable tags from values of type ['v] to tags of type [bool] *)
type 'v tag_set = { type 'v tag_set = {
get_tag: 'v -> bool; get_tag: 'v -> bool;
set_tag: 'v -> unit; (** Set tag to [true] for the given element *) set_tag: 'v -> unit; (** Set tag for the given element *)
} }
(** Mutable table with keys ['k] and values ['a] *) (** Mutable table with keys ['k] and values ['a] *)
@ -70,10 +73,13 @@ type ('k, 'a) table = {
(** Mutable set *) (** Mutable set *)
type 'a set = ('a, unit) table type 'a set = ('a, unit) table
(** Default implementation for {!table}: a {!Hashtbl.t} *)
val mk_table: ?eq:('k -> 'k -> bool) -> ?hash:('k -> int) -> int -> ('k, 'a) table val mk_table: ?eq:('k -> 'k -> bool) -> ?hash:('k -> int) -> int -> ('k, 'a) table
(** Default implementation for {!table}: a {!Hashtbl.t} *)
(** {2 Traversals} *) val mk_map: ?cmp:('k -> 'k -> int) -> unit -> ('k, 'a) table
(** Use a {!Map.S} underneath *)
(** {2 Bags of vertices} *)
(** Bag of elements of type ['a] *) (** Bag of elements of type ['a] *)
type 'a bag = { type 'a bag = {
@ -89,34 +95,50 @@ val mk_heap: leq:('a -> 'a -> bool) -> 'a bag
(** [mk_heap ~leq] makes a priority queue where [leq x y = true] means that (** [mk_heap ~leq] makes a priority queue where [leq x y = true] means that
[x] is smaller than [y] and should be prioritary *) [x] is smaller than [y] and should be prioritary *)
val traverse: ?tbl:(int -> 'v set) -> (** {2 Traversals} *)
bag:(unit -> 'v bag) ->
module Traverse : sig
val generic: ?tbl:'v set ->
bag:'v bag ->
graph:('v, 'e) t -> graph:('v, 'e) t ->
'v sequence -> 'v sequence 'v sequence ->
'v sequence_once
(** Traversal of the given graph, starting from a sequence (** Traversal of the given graph, starting from a sequence
of vertices, using the given bag to choose the next vertex to of vertices, using the given bag to choose the next vertex to
explore. Each vertex is visited at most once. *) explore. Each vertex is visited at most once. *)
val traverse_tag: tags:'v tag_set -> val generic_tag: tags:'v tag_set ->
bag:'v bag -> bag:'v bag ->
graph:('v, 'e) t -> graph:('v, 'e) t ->
'v sequence -> 'v sequence ->
'v sequence_once 'v sequence_once
(** One-shot traversal of the graph using a tag set and the given bag *) (** One-shot traversal of the graph using a tag set and the given bag *)
val bfs: ?tbl:(int -> 'v set) -> graph:('v, 'e) t -> 'v sequence -> 'v sequence val dfs: ?tbl:'v set ->
graph:('v, 'e) t ->
'v sequence ->
'v sequence_once
val bfs_tag: tags:'v tag_set -> graph:('v, 'e) t -> 'v sequence -> 'v sequence_once val dfs_tag: tags:'v tag_set ->
graph:('v, 'e) t ->
'v sequence ->
'v sequence_once
val dfs: ?tbl:(int -> 'v set) -> graph:('v, 'e) t -> 'v sequence -> 'v sequence val bfs: ?tbl:'v set ->
graph:('v, 'e) t ->
'v sequence ->
'v sequence_once
val dfs_tag: tags:'v tag_set -> graph:('v, 'e) t -> 'v sequence -> 'v sequence_once val bfs_tag: tags:'v tag_set ->
graph:('v, 'e) t ->
'v sequence ->
'v sequence_once
val dijkstra : ?tbl:(int -> 'v set) -> val dijkstra : ?tbl:'v set ->
?dist:('e -> int) -> ?dist:('e -> int) ->
graph:('v, 'e) t -> graph:('v, 'e) t ->
'v sequence -> 'v sequence ->
('v * int) sequence ('v * int) sequence_once
(** Dijkstra algorithm, traverses a graph in increasing distance order. (** Dijkstra algorithm, traverses a graph in increasing distance order.
Yields each vertex paired with its distance to the set of initial vertices Yields each vertex paired with its distance to the set of initial vertices
(the smallest distance needed to reach the node from the initial vertices) (the smallest distance needed to reach the node from the initial vertices)
@ -129,3 +151,73 @@ val dijkstra_tag : ?dist:('e -> int) ->
'v sequence -> 'v sequence ->
('v * int) sequence_once ('v * int) sequence_once
(** {2 More detailed interface} *)
module Event : sig
type edge_kind = [`Forward | `Back | `Cross ]
type 'e path = 'e list
(** A traversal is a sequence of such events *)
type ('v,'e) t =
[ `Enter of 'v * int * 'e path (* unique index in traversal, path from start *)
| `Exit of 'v
| `Edge of 'e * edge_kind
]
val get_vertex : ('v, 'e) t -> ('v * [`Enter | `Exit]) option
val get_enter : ('v, 'e) t -> 'v option
val get_exit : ('v, 'e) t -> 'v option
val get_edge : ('v, 'e) t -> 'e option
val get_edge_kind : ('v, 'e) t -> ('e * edge_kind) option
val dfs: ?tbl:'v set ->
?eq:('v -> 'v -> bool) ->
graph:('v, 'e) graph ->
'v sequence ->
('v,'e) t sequence_once
(** Full version of DFS.
@param eq equality predicate on vertices *)
val dfs_tag: ?eq:('v -> 'v -> bool) ->
tags:'v tag_set ->
graph:('v, 'e) graph ->
'v sequence ->
('v,'e) t sequence_once
(** Full version of DFS using integer tags
@param eq equality predicate on vertices *)
end
end
(** {2 Pretty printing in the DOT (graphviz) format} *)
module Dot : sig
type attribute = [
| `Color of string
| `Shape of string
| `Weight of int
| `Style of string
| `Label of string
| `Other of string * string
] (** Dot attribute *)
val pp : ?tbl:('v,int) table ->
?attrs_v:('v -> attribute list) ->
?attrs_e:('e -> attribute list) ->
?name:string ->
graph:('v,'e) t ->
Format.formatter ->
'v ->
unit
val pp_seq : ?tbl:('v,int) table ->
?attrs_v:('v -> attribute list) ->
?attrs_e:('e -> attribute list) ->
?name:string ->
graph:('v,'e) t ->
Format.formatter ->
'v sequence ->
unit
val with_out : string -> (Format.formatter -> 'a) -> 'a
(** Shortcut to open a file and write to it *)
end