add a few functions to CCArray1

This commit is contained in:
Simon Cruanes 2015-06-24 00:13:40 +02:00
parent 2ced134868
commit a4f0e17799
2 changed files with 40 additions and 14 deletions

View file

@ -81,8 +81,8 @@ let init ~kind ~f n =
done; done;
a a
let of_array a = a let of_bigarray a = a
let to_array a = a let to_bigarray a = a
let ro (t : ('a,'b,[>`R]) t) : ('a,'b,[`R]) t = t let ro (t : ('a,'b,[>`R]) t) : ('a,'b,[`R]) t = t
let wo (t : ('a,'b,[>`W]) t) : ('a,'b,[`W]) t = t let wo (t : ('a,'b,[>`W]) t) : ('a,'b,[`W]) t = t
@ -560,6 +560,24 @@ module Float = struct
include Infix include Infix
end end
let to_list a =
let l = foldi (fun acc _ x -> x::acc) [] a in
List.rev l
let to_array a =
if A.dim a = 0 then [||]
else (
let b = Array.make (A.dim a) (A.get a 0) in
for i = 1 to A.dim a - 1 do
Array.unsafe_set b i (A.unsafe_get a i)
done;
b
)
let to_seq a yield = iter a ~f:yield
let of_array ~kind a = A.of_array kind Bigarray.c_layout a
exception OfYojsonError of string exception OfYojsonError of string
let to_yojson (f:'a -> json) a : json = let to_yojson (f:'a -> json) a : json =
@ -672,7 +690,7 @@ module View = struct
let select_a ~idx a = {len=Array.length idx; view=SelectA(idx,a)} let select_a ~idx a = {len=Array.length idx; view=SelectA(idx,a)}
let select_view ~idx a = {len=length idx; view=SelectV(idx,a)} let select_view ~idx a = {len=length idx; view=SelectV(idx,a)}
let fold f acc a = let foldi f acc a =
let acc = ref acc in let acc = ref acc in
iteri a ~f:(fun i x -> acc := f !acc i x); iteri a ~f:(fun i x -> acc := f !acc i x);
!acc !acc
@ -693,8 +711,8 @@ module View = struct
type elt = int type elt = int
let add a b = map2 ~f:(+) a b let add a b = map2 ~f:(+) a b
let mult a b = map2 ~f:( * ) a b let mult a b = map2 ~f:( * ) a b
let sum a = fold (fun acc _ x -> acc+x) 0 a let sum a = foldi (fun acc _ x -> acc+x) 0 a
let prod a = fold (fun acc _ x -> acc*x) 1 a let prod a = foldi (fun acc _ x -> acc*x) 1 a
let add_scalar a ~x = map ~f:(fun y -> x+y) a let add_scalar a ~x = map ~f:(fun y -> x+y) a
let mult_scalar a ~x = map ~f:(fun y -> x*y) a let mult_scalar a ~x = map ~f:(fun y -> x*y) a
end end
@ -703,8 +721,8 @@ module View = struct
type elt = float type elt = float
let add a b = map2 ~f:(+.) a b let add a b = map2 ~f:(+.) a b
let mult a b = map2 ~f:( *. ) a b let mult a b = map2 ~f:( *. ) a b
let sum a = fold (fun acc _ x -> acc+.x) 0. a let sum a = foldi (fun acc _ x -> acc+.x) 0. a
let prod a = fold (fun acc _ x -> acc*.x) 1. a let prod a = foldi (fun acc _ x -> acc*.x) 1. a
let add_scalar a ~x = map ~f:(fun y -> x+.y) a let add_scalar a ~x = map ~f:(fun y -> x+.y) a
let mult_scalar a ~x = map ~f:(fun y -> x*.y) a let mult_scalar a ~x = map ~f:(fun y -> x*.y) a
end end
@ -720,5 +738,4 @@ module View = struct
in in
iteri a ~f:(fun i x -> A.unsafe_set res i x); iteri a ~f:(fun i x -> A.unsafe_set res i x);
res res
end end

View file

@ -24,7 +24,9 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*) *)
(** {1 Bigarrays of dimension 1 *) (** {1 Bigarrays of dimension 1}
@since NEXT_RELEASE *)
(** {2 used types} *) (** {2 used types} *)
@ -79,10 +81,10 @@ val make_complex64 : int -> (Complex.t, Bigarray.complex64_elt, 'perm) t
val init : kind:('a, 'b) Bigarray.kind -> f:(int -> 'a) -> int -> ('a, 'b, 'perm) t val init : kind:('a, 'b) Bigarray.kind -> f:(int -> 'a) -> int -> ('a, 'b, 'perm) t
(** Initialize with given size and initialization function *) (** Initialize with given size and initialization function *)
val of_array : ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t -> ('a, 'b, 'perm) t val of_bigarray : ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t -> ('a, 'b, 'perm) t
(** Convert from an array *) (** Convert from a big array *)
val to_array : ('a, 'b, [`R | `W]) t -> ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t val to_bigarray : ('a, 'b, [`R | `W]) t -> ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t
(** Obtain the underlying array *) (** Obtain the underlying array *)
val ro : ('a, 'b, [>`R]) t -> ('a, 'b, [`R]) t val ro : ('a, 'b, [>`R]) t -> ('a, 'b, [`R]) t
@ -274,6 +276,14 @@ module Int : S with type elt = int
module Float : S with type elt = float module Float : S with type elt = float
(** {2 Conversions} *)
val to_list : ('a, _, [>`R]) t -> 'a list
val to_array : ('a, _, [>`R]) t -> 'a array
val to_seq : ('a, _, [>`R]) t -> 'a sequence
val of_array : kind:('a, 'b) Bigarray.kind -> 'a array -> ('a, 'b, 'perm) t
(** {2 Serialization} *) (** {2 Serialization} *)
val to_yojson : 'a to_json -> ('a, _, [>`R]) t to_json val to_yojson : 'a to_json -> ('a, _, [>`R]) t to_json
@ -317,7 +327,7 @@ module View : sig
val select_view : idx:int t -> 'a t -> 'a t val select_view : idx:int t -> 'a t -> 'a t
(** See {!select} *) (** See {!select} *)
val fold : ('b -> int -> 'a -> 'b) -> 'b -> 'a t -> 'b val foldi : ('b -> int -> 'a -> 'b) -> 'b -> 'a t -> 'b
(** fold on values with their index *) (** fold on values with their index *)
val iteri : f:(int -> 'a -> unit) -> 'a t -> unit val iteri : f:(int -> 'a -> unit) -> 'a t -> unit
@ -355,7 +365,6 @@ module View : sig
('a, 'b, 'perm) array_ ('a, 'b, 'perm) array_
(** [to_array v] returns a fresh copy of the content of [v]. (** [to_array v] returns a fresh copy of the content of [v].
Exactly one of [res] and [kind] must be provided *) Exactly one of [res] and [kind] must be provided *)
end end