refactor: also port CCGraph to iter

This commit is contained in:
Simon Cruanes 2019-12-14 16:41:49 -06:00
parent 264e89b198
commit 1947d1804b
2 changed files with 181 additions and 103 deletions

View file

@ -5,16 +5,30 @@
(** {2 Iter Helpers} *)
type 'a sequence = ('a -> unit) -> unit
type 'a iter = ('a -> unit) -> unit
(** An iterator over items of type ['a], possibly infinite
@since NEXT_RELEASE *)
type 'a sequence_once = 'a sequence
type 'a iter_once = 'a iter
(** Iter that should be used only once
@since NEXT_RELEASE *)
type 'a sequence = ('a -> unit) -> unit
(** A sequence of items of type ['a], possibly infinite
@deprecated see {!iter} instead *)
[@@ocaml.deprecated "see iter"]
type 'a sequence_once = 'a iter
(** Iter that should be used only once
@deprecated see {!iter_once} instead *)
[@@ocaml.deprecated "see iter_once"]
exception Iter_once
let (|>) x f = f x
module Seq = struct
type 'a t = 'a sequence
module Iter = struct
type 'a t = 'a iter
let return x k = k x
let (>>=) a f k = a (fun x -> f x k)
let map f a k = a (fun x -> k (f x))
@ -31,14 +45,16 @@ module Seq = struct
with Exit_ -> true
end
module Seq = Iter
(** {2 Interfaces for graphs} *)
(** Directed graph with vertices of type ['v] and edges labeled with ['e] *)
type ('v, 'e) t = ('v -> ('e * 'v) sequence)
type ('v, 'e) t = ('v -> ('e * 'v) iter)
type ('v, 'e) graph = ('v, 'e) t
let make (f:'v->('e*'v) sequence): ('v, 'e) t = f
let make (f:'v->('e*'v) iter): ('v, 'e) t = f
(** Mutable bitset for values of type ['v] *)
type 'v tag_set = {
@ -143,66 +159,66 @@ let mk_heap ~leq =
module Traverse = struct
type ('v, 'e) path = ('v * 'e * 'v) list
let generic_tag ~tags ~bag ~graph seq =
let generic_tag ~tags ~bag ~graph iter =
let first = ref true in
fun k ->
(* ensure linearity *)
if !first then first := false else raise Iter_once;
Seq.iter bag.push seq;
Iter.iter bag.push iter;
while not (bag.is_empty ()) do
let x = bag.pop () in
if not (tags.get_tag x) then (
k x;
tags.set_tag x;
Seq.iter
Iter.iter
(fun (_,dest) -> bag.push dest)
(graph x)
)
done
let generic ~tbl ~bag ~graph seq =
let generic ~tbl ~bag ~graph iter =
let tags = {
get_tag=tbl.mem;
set_tag=(fun v -> tbl.add v ());
} in
generic_tag ~tags ~bag ~graph seq
generic_tag ~tags ~bag ~graph iter
let bfs ~tbl ~graph seq =
generic ~tbl ~bag:(mk_queue ()) ~graph seq
let bfs ~tbl ~graph iter =
generic ~tbl ~bag:(mk_queue ()) ~graph iter
let bfs_tag ~tags ~graph seq =
generic_tag ~tags ~bag:(mk_queue()) ~graph seq
let bfs_tag ~tags ~graph iter =
generic_tag ~tags ~bag:(mk_queue()) ~graph iter
let dijkstra_tag ?(dist=fun _ -> 1) ~tags ~graph seq =
let dijkstra_tag ?(dist=fun _ -> 1) ~tags ~graph iter =
let tags' = {
get_tag=(fun (v,_,_) -> tags.get_tag v);
set_tag=(fun (v,_,_) -> tags.set_tag v);
}
and seq' = Seq.map (fun v -> v, 0, []) seq
and iter' = Iter.map (fun v -> v, 0, []) iter
and graph' (v,d,p) =
graph v
|> Seq.map (fun (e,v') -> e, (v',d+dist e, (v,e,v')::p))
|> Iter.map (fun (e,v') -> e, (v',d+dist e, (v,e,v')::p))
in
let bag = mk_heap ~leq:(fun (_,d1,_) (_,d2,_) -> d1 <= d2) in
generic_tag ~tags:tags' ~bag ~graph:graph' seq'
generic_tag ~tags:tags' ~bag ~graph:graph' iter'
let dijkstra ~tbl ?dist ~graph seq =
let dijkstra ~tbl ?dist ~graph iter =
let tags = {
get_tag=tbl.mem;
set_tag=(fun v -> tbl.add v ());
} in
dijkstra_tag ~tags ?dist ~graph seq
dijkstra_tag ~tags ?dist ~graph iter
let dfs ~tbl ~graph seq =
generic ~tbl ~bag:(mk_stack ()) ~graph seq
let dfs ~tbl ~graph iter =
generic ~tbl ~bag:(mk_stack ()) ~graph iter
let dfs_tag ~tags ~graph seq =
generic_tag ~tags ~bag:(mk_stack()) ~graph seq
let dfs_tag ~tags ~graph iter =
generic_tag ~tags ~bag:(mk_stack()) ~graph iter
module Event = struct
type edge_kind = [`Forward | `Back | `Cross ]
(** A traversal is a sequence of such events *)
(** A traversal is an iterator of such events *)
type ('v,'e) t =
[ `Enter of 'v * int * ('v,'e) path (* unique index in traversal, path from start *)
| `Exit of 'v
@ -240,13 +256,13 @@ module Traverse = struct
| (v1,_,_) :: path' ->
eq v v1 || list_mem_ ~eq ~graph v path'
let dfs_tag ~eq ~tags ~graph seq =
let dfs_tag ~eq ~tags ~graph iter =
let first = ref true in
fun k ->
if !first then first := false else raise Iter_once;
let bag = mk_stack() in
let n = ref 0 in
Seq.iter
Iter.iter
(fun v ->
(* start DFS from this vertex *)
bag.push (`Enter (v, []));
@ -259,7 +275,7 @@ module Traverse = struct
tags.set_tag v;
k (`Enter (v, num, path));
bag.push (`Exit v);
Seq.iter
Iter.iter
(fun (e,v') -> bag.push (`Edge (v,e,v',(v,e,v') :: path)))
(graph v);
)
@ -277,14 +293,14 @@ module Traverse = struct
in
k (`Edge (v,e,v', edge_kind))
done
) seq
) iter
let dfs ~tbl ~eq ~graph seq =
let dfs ~tbl ~eq ~graph iter =
let tags = {
set_tag=(fun v -> tbl.add v ());
get_tag=tbl.mem;
} in
dfs_tag ~eq ~tags ~graph seq
dfs_tag ~eq ~tags ~graph iter
end
(*$R
@ -308,7 +324,7 @@ end
let is_dag ~tbl ~eq ~graph vs =
Traverse.Event.dfs ~tbl ~eq ~graph vs
|> Seq.exists_
|> Iter.exists_
(function
| `Edge (_, _, _, `Back) -> true
| _ -> false)
@ -317,38 +333,38 @@ let is_dag ~tbl ~eq ~graph vs =
exception Has_cycle
let topo_sort_tag ~eq ?(rev=false) ~tags ~graph seq =
let topo_sort_tag ~eq ?(rev=false) ~tags ~graph iter =
(* use DFS *)
let l =
Traverse.Event.dfs_tag ~eq ~tags ~graph seq
|> Seq.filter_map
Traverse.Event.dfs_tag ~eq ~tags ~graph iter
|> Iter.filter_map
(function
| `Exit v -> Some v
| `Edge (_, _, _, `Back) -> raise Has_cycle
| `Enter _
| `Edge _ -> None
)
|> Seq.fold (fun acc x -> x::acc) []
|> Iter.fold (fun acc x -> x::acc) []
in
if rev then List.rev l else l
let topo_sort ~eq ?rev ~tbl ~graph seq =
let topo_sort ~eq ?rev ~tbl ~graph iter =
let tags = {
get_tag=tbl.mem;
set_tag=(fun v -> tbl.add v ());
} in
topo_sort_tag ~eq ?rev ~tags ~graph seq
topo_sort_tag ~eq ?rev ~tags ~graph iter
(*$T
let tbl = mk_table ~eq:CCInt.equal 128 in \
let l = topo_sort ~eq:CCInt.equal ~tbl ~graph:divisors_graph (Seq.return 42) in \
let l = topo_sort ~eq:CCInt.equal ~tbl ~graph:divisors_graph (Iter.return 42) in \
List.for_all (fun (i,j) -> \
let idx_i = CCList.find_idx ((=)i) l |> CCOpt.get_exn |> fst in \
let idx_j = CCList.find_idx ((=)j) l |> CCOpt.get_exn |> fst in \
idx_i < idx_j) \
[ 42, 21; 14, 2; 3, 1; 21, 7; 42, 3]
let tbl = mk_table ~eq:CCInt.equal 128 in \
let l = topo_sort ~eq:CCInt.equal ~rev:true ~tbl ~graph:divisors_graph (Seq.return 42) in \
let l = topo_sort ~eq:CCInt.equal ~rev:true ~tbl ~graph:divisors_graph (Iter.return 42) in \
List.for_all (fun (i,j) -> \
let idx_i = CCList.find_idx ((=)i) l |> CCOpt.get_exn |> fst in \
let idx_j = CCList.find_idx ((=)j) l |> CCOpt.get_exn |> fst in \
@ -381,7 +397,7 @@ end
let spanning_tree_tag ~tags ~graph v =
let rec mk_node v =
let children = lazy (
Seq.fold
Iter.fold
(fun acc (e,v') ->
if tags.get_tag v'
then acc
@ -430,7 +446,7 @@ module SCC = struct
cell.vertex :: acc (* return SCC *)
) else pop_down_to ~id (cell.vertex::acc) stack
let explore ~tbl ~graph seq =
let explore ~tbl ~graph iter =
let first = ref true in
fun k ->
if !first then first := false else raise Iter_once;
@ -441,7 +457,7 @@ module SCC = struct
(* unique ID *)
let n = ref 0 in
(* exploration *)
Seq.iter
Iter.iter
(fun v ->
Stack.push (`Enter v) to_explore;
while not (Stack.is_empty to_explore) do
@ -457,14 +473,14 @@ module SCC = struct
Stack.push cell stack;
Stack.push (`Exit (v, cell)) to_explore;
(* explore children *)
Seq.iter
Iter.iter
(fun (_,v') -> Stack.push (`Enter v') to_explore)
(graph v)
)
| `Exit (v, cell) ->
(* update [min_id] *)
assert cell.on_stack;
Seq.iter
Iter.iter
(fun (_,dest) ->
(* must not fail, [dest] already explored *)
let dest_cell = tbl.find dest in
@ -478,14 +494,14 @@ module SCC = struct
k scc
)
done
) seq;
) iter;
assert (Stack.is_empty stack);
()
end
type 'v scc_state = 'v SCC.state
let scc ~tbl ~graph seq = SCC.explore ~tbl ~graph seq
let scc ~tbl ~graph iter = SCC.explore ~tbl ~graph iter
(* example from https://en.wikipedia.org/wiki/Strongly_connected_component *)
(*$R
@ -507,7 +523,7 @@ let scc ~tbl ~graph seq = SCC.explore ~tbl ~graph seq
; "h", "g"
] in
let tbl = mk_table ~eq:CCString.equal 128 in
let res = scc ~tbl ~graph (Seq.return "a") |> Seq.to_list in
let res = scc ~tbl ~graph (Iter.return "a") |> Iter.to_list in
assert_bool "scc"
(set_eq ~eq:(set_eq ?eq:None) res
[ [ "a"; "b"; "e" ]
@ -544,13 +560,13 @@ module Dot = struct
}
(** Print an enum of Full.traverse_event *)
let pp_seq
let pp_all
~tbl
~eq
?(attrs_v=fun _ -> [])
?(attrs_e=fun _ -> [])
?(name="graph")
~graph out seq =
~graph out iter =
(* print an attribute *)
let pp_attr out attr = match attr with
| `Color c -> Format.fprintf out "color=%s" c
@ -584,8 +600,8 @@ module Dot = struct
get_tag=vertex_explored;
set_tag=set_explored; (* allocate new ID *)
} in
let events = Traverse.Event.dfs_tag ~eq ~tags ~graph seq in
Seq.iter
let events = Traverse.Event.dfs_tag ~eq ~tags ~graph iter in
Iter.iter
(function
| `Enter (v, _n, _path) ->
let attrs = attrs_v v in
@ -602,8 +618,10 @@ module Dot = struct
Format.fprintf out "}@]@;@?";
()
let pp_seq = pp_all
let pp ~tbl ~eq ?attrs_v ?attrs_e ?name ~graph fmt v =
pp_seq ~tbl ~eq ?attrs_v ?attrs_e ?name ~graph fmt (Seq.return v)
pp_all ~tbl ~eq ?attrs_v ?attrs_e ?name ~graph fmt (Iter.return v)
let with_out filename f =
let oc = open_out filename in
@ -652,7 +670,7 @@ module type MAP = sig
type 'a t
val as_graph : 'a t -> (vertex, 'a) graph
(** Graph view of the map *)
(** Graph view of the map. *)
val empty : 'a t
@ -661,16 +679,16 @@ module type MAP = sig
val remove_edge : vertex -> vertex -> 'a t -> 'a t
val add : vertex -> 'a t -> 'a t
(** Add a vertex, possibly with no outgoing edge *)
(** Add a vertex, possibly with no outgoing edge. *)
val remove : vertex -> 'a t -> 'a t
(** Remove the vertex and all its outgoing edges.
Edges that point to the vertex are {b NOT} removed, they must be
manually removed with {!remove_edge} *)
manually removed with {!remove_edge}. *)
val union : 'a t -> 'a t -> 'a t
val vertices : _ t -> vertex sequence
val vertices : _ t -> vertex iter
val vertices_l : _ t -> vertex list
@ -680,11 +698,23 @@ module type MAP = sig
val to_list : 'a t -> (vertex * 'a * vertex) list
val of_seq : (vertex * 'a * vertex) sequence -> 'a t
val of_iter : (vertex * 'a * vertex) iter -> 'a t
(** @since NEXT_RELEASE *)
val add_seq : (vertex * 'a * vertex) sequence -> 'a t -> 'a t
val add_iter : (vertex * 'a * vertex) iter -> 'a t -> 'a t
(** @since NEXT_RELEASE *)
val to_seq : 'a t -> (vertex * 'a * vertex) sequence
val to_iter : 'a t -> (vertex * 'a * vertex) iter
(** @since NEXT_RELEASE *)
val of_seq : (vertex * 'a * vertex) iter -> 'a t
(** @deprecated use {!of_iter} instead *)
val add_seq : (vertex * 'a * vertex) iter -> 'a t -> 'a t
(** @deprecated use {!add_iter} instead *)
val to_seq : 'a t -> (vertex * 'a * vertex) iter
(** @deprecated use {!to_iter} instead *)
end
module Map(O : Map.OrderedType) : MAP with type vertex = O.t = struct
@ -752,11 +782,15 @@ module Map(O : Map.OrderedType) : MAP with type vertex = O.t = struct
(fun v map acc -> M.fold (fun v' e acc -> (v,e,v')::acc) map acc)
m []
let add_seq seq m = Seq.fold (fun m (v1,e,v2) -> add_edge v1 e v2 m) m seq
let add_iter iter m = Iter.fold (fun m (v1,e,v2) -> add_edge v1 e v2 m) m iter
let of_seq seq = add_seq seq empty
let of_iter iter = add_iter iter empty
let to_seq m k = M.iter (fun v map -> M.iter (fun v' e -> k(v,e,v')) map) m
let to_iter m k = M.iter (fun v map -> M.iter (fun v' e -> k(v,e,v')) map) m
let add_seq = add_iter
let of_seq = of_iter
let to_seq = to_iter
end
(** {2 Misc} *)

View file

@ -14,9 +14,11 @@
This abstract notion of graph makes it possible to run the algorithms
on any user-specific type that happens to have a graph structure.
Many graph algorithms here take a sequence of vertices as input.
Many graph algorithms here take an iterator of vertices as input.
The helper module {!Iter} contains basic functions for that, as does
the [iter] library on opam.
If the user only has a single vertex (e.g., for a topological sort
from a given vertex), she can use [Seq.return x] to build a sequence
from a given vertex), they can use [Iter.return x] to build an iterator
of one element.
{b status: unstable}
@ -25,18 +27,30 @@
(** {2 Iter Helpers} *)
type 'a iter = ('a -> unit) -> unit
(** An iterator over items of type ['a], possibly infinite
@since NEXT_RELEASE *)
type 'a iter_once = 'a iter
(** Iter that should be used only once
@since NEXT_RELEASE *)
type 'a sequence = ('a -> unit) -> unit
(** A sequence of items of type ['a], possibly infinite *)
(** A sequence of items of type ['a], possibly infinite
@deprecated see {!iter} instead *)
[@@ocaml.deprecated "see iter"]
type 'a sequence_once = 'a sequence
(** Iter that should be used only once *)
(** Iter that should be used only once
@deprecated see {!iter_once} instead *)
[@@ocaml.deprecated "see iter_once"]
exception Iter_once
(** Raised when an iterator meant to be used once is used several times. *)
module Seq : sig
type 'a t = 'a sequence
val return : 'a -> 'a sequence
module Iter : sig
type 'a t = 'a iter
val return : 'a -> 'a t
val (>>=) : 'a t -> ('a -> 'b t) -> 'b t
val map : ('a -> 'b) -> 'a t -> 'b t
val filter_map : ('a -> 'b option) -> 'a t -> 'b t
@ -45,16 +59,20 @@ module Seq : sig
val to_list : 'a t -> 'a list
end
module Seq = Iter
(** @deprecated use {!Iter} instead *)
[@@ocaml.deprecated "use {!Iter} instead"]
(** {2 Interfaces for graphs}
This interface is designed for oriented graphs with labels on edges *)
(** Directed graph with vertices of type ['v] and edges labeled with ['e] *)
type ('v, 'e) t = ('v -> ('e * 'v) sequence)
type ('v, 'e) t = ('v -> ('e * 'v) iter)
type ('v, 'e) graph = ('v, 'e) t
val make : ('v -> ('e * 'v) sequence) -> ('v, 'e) t
val make : ('v -> ('e * 'v) iter) -> ('v, 'e) t
(** Make a graph by providing the children function. *)
(** {2 Tags}
@ -107,8 +125,8 @@ module Traverse : sig
val generic: tbl:'v set ->
bag:'v bag ->
graph:('v, 'e) t ->
'v sequence ->
'v sequence_once
'v iter ->
'v iter_once
(** Traversal of the given graph, starting from a sequence
of vertices, using the given bag to choose the next vertex to
explore. Each vertex is visited at most once. *)
@ -116,35 +134,35 @@ module Traverse : sig
val generic_tag: tags:'v tag_set ->
bag:'v bag ->
graph:('v, 'e) t ->
'v sequence ->
'v sequence_once
'v iter ->
'v iter_once
(** One-shot traversal of the graph using a tag set and the given bag. *)
val dfs: tbl:'v set ->
graph:('v, 'e) t ->
'v sequence ->
'v sequence_once
'v iter ->
'v iter_once
val dfs_tag: tags:'v tag_set ->
graph:('v, 'e) t ->
'v sequence ->
'v sequence_once
'v iter ->
'v iter_once
val bfs: tbl:'v set ->
graph:('v, 'e) t ->
'v sequence ->
'v sequence_once
'v iter ->
'v iter_once
val bfs_tag: tags:'v tag_set ->
graph:('v, 'e) t ->
'v sequence ->
'v sequence_once
'v iter ->
'v iter_once
val dijkstra : tbl:'v set ->
?dist:('e -> int) ->
graph:('v, 'e) t ->
'v sequence ->
('v * int * ('v,'e) path) sequence_once
'v iter ->
('v * int * ('v,'e) path) iter_once
(** Dijkstra algorithm, traverses a graph in increasing distance order.
Yields each vertex paired with its distance to the set of initial vertices
(the smallest distance needed to reach the node from the initial vertices).
@ -154,8 +172,8 @@ module Traverse : sig
val dijkstra_tag : ?dist:('e -> int) ->
tags:'v tag_set ->
graph:('v, 'e) t ->
'v sequence ->
('v * int * ('v,'e) path) sequence_once
'v iter ->
('v * int * ('v,'e) path) iter_once
(** {2 More detailed interface} *)
module Event : sig
@ -177,16 +195,16 @@ module Traverse : sig
val dfs: tbl:'v set ->
eq:('v -> 'v -> bool) ->
graph:('v, 'e) graph ->
'v sequence ->
('v,'e) t sequence_once
'v iter ->
('v,'e) t iter_once
(** Full version of DFS.
@param eq equality predicate on vertices. *)
val dfs_tag: eq:('v -> 'v -> bool) ->
tags:'v tag_set ->
graph:('v, 'e) graph ->
'v sequence ->
('v,'e) t sequence_once
'v iter ->
('v,'e) t iter_once
(** Full version of DFS using integer tags.
@param eq equality predicate on vertices. *)
end
@ -198,7 +216,7 @@ val is_dag :
tbl:'v set ->
eq:('v -> 'v -> bool) ->
graph:('v, _) t ->
'v sequence ->
'v iter ->
bool
(** [is_dag ~graph vs] returns [true] if the subset of [graph] reachable
from [vs] is acyclic.
@ -212,7 +230,7 @@ val topo_sort : eq:('v -> 'v -> bool) ->
?rev:bool ->
tbl:'v set ->
graph:('v, 'e) t ->
'v sequence ->
'v iter ->
'v list
(** [topo_sort ~graph seq] returns a list of vertices [l] where each
element of [l] is reachable from [seq].
@ -229,7 +247,7 @@ val topo_sort_tag : eq:('v -> 'v -> bool) ->
?rev:bool ->
tags:'v tag_set ->
graph:('v, 'e) t ->
'v sequence ->
'v iter ->
'v list
(** Same as {!topo_sort} but uses an explicit tag set. *)
@ -265,8 +283,8 @@ type 'v scc_state
val scc : tbl:('v, 'v scc_state) table ->
graph:('v, 'e) t ->
'v sequence ->
'v list sequence_once
'v iter ->
'v list iter_once
(** Strongly connected components reachable from the given vertices.
Each component is a list of vertices that are all mutually reachable
in the graph.
@ -319,6 +337,18 @@ module Dot : sig
@param attrs_e attributes for edges.
@param name name of the graph. *)
val pp_all : tbl:('v,vertex_state) table ->
eq:('v -> 'v -> bool) ->
?attrs_v:('v -> attribute list) ->
?attrs_e:('e -> attribute list) ->
?name:string ->
graph:('v,'e) t ->
Format.formatter ->
'v iter ->
unit
(** Same as {!pp} but starting from several vertices, not just one.
@since NEXT_RELEASE *)
val pp_seq : tbl:('v,vertex_state) table ->
eq:('v -> 'v -> bool) ->
?attrs_v:('v -> attribute list) ->
@ -326,8 +356,10 @@ module Dot : sig
?name:string ->
graph:('v,'e) t ->
Format.formatter ->
'v sequence ->
'v iter ->
unit
(** @deprecated see {!pp_all} instead *)
[@@ocaml.deprecated "use {!pp_all} instead"]
val with_out : string -> (Format.formatter -> 'a) -> 'a
(** Shortcut to open a file and write to it. *)
@ -377,7 +409,7 @@ module type MAP = sig
val union : 'a t -> 'a t -> 'a t
val vertices : _ t -> vertex sequence
val vertices : _ t -> vertex iter
val vertices_l : _ t -> vertex list
@ -387,11 +419,23 @@ module type MAP = sig
val to_list : 'a t -> (vertex * 'a * vertex) list
val of_seq : (vertex * 'a * vertex) sequence -> 'a t
val of_iter : (vertex * 'a * vertex) iter -> 'a t
(** @since NEXT_RELEASE *)
val add_seq : (vertex * 'a * vertex) sequence -> 'a t -> 'a t
val add_iter : (vertex * 'a * vertex) iter -> 'a t -> 'a t
(** @since NEXT_RELEASE *)
val to_seq : 'a t -> (vertex * 'a * vertex) sequence
val to_iter : 'a t -> (vertex * 'a * vertex) iter
(** @since NEXT_RELEASE *)
val of_seq : (vertex * 'a * vertex) iter -> 'a t
(** @deprecated use {!of_iter} instead *)
val add_seq : (vertex * 'a * vertex) iter -> 'a t -> 'a t
(** @deprecated use {!add_iter} instead *)
val to_seq : 'a t -> (vertex * 'a * vertex) iter
(** @deprecated use {!to_iter} instead *)
end
module Map(O : Map.OrderedType) : MAP with type vertex = O.t