1use crate::error::OpResult;
5use crate::error_helpers::try_from_numeric;
6use crate::ndarray;
7
8use crate::Float;
9
10pub type NdArray<T> = scirs2_core::ndarray::Array<T, scirs2_core::ndarray::IxDyn>;
12
13pub type NdArrayView<'a, T> = scirs2_core::ndarray::ArrayView<'a, T, scirs2_core::ndarray::IxDyn>;
15
16pub type RawNdArrayView<T> = scirs2_core::ndarray::RawArrayView<T, scirs2_core::ndarray::IxDyn>;
18
19pub type RawNdArrayViewMut<T> =
21 scirs2_core::ndarray::RawArrayViewMut<T, scirs2_core::ndarray::IxDyn>;
22
23pub type NdArrayViewMut<'a, T> =
25 scirs2_core::ndarray::ArrayViewMut<'a, T, scirs2_core::ndarray::IxDyn>;
26
27#[inline]
28pub(crate) fn asshape<T: Float>(x: &NdArrayView<T>) -> Vec<usize> {
30 x.iter().map(|a| a.to_usize().unwrap_or(0)).collect()
31}
32
33#[inline]
34pub(crate) fn expand_dims<T: Float>(x: NdArray<T>, axis: usize) -> NdArray<T> {
35 let mut shape = x.shape().to_vec();
36 shape.insert(axis, 1);
37 x.into_shape_with_order(shape)
38 .expect("Shape conversion failed - this is a bug")
39}
40
41#[inline]
42pub(crate) fn roll_axis<T: Float>(
43 arg: &mut NdArray<T>,
44 to: scirs2_core::ndarray::Axis,
45 from: scirs2_core::ndarray::Axis,
46) {
47 let i = to.index();
48 let mut j = from.index();
49 if j > i {
50 while i != j {
51 arg.swap_axes(i, j);
52 j -= 1;
53 }
54 } else {
55 while i != j {
56 arg.swap_axes(i, j);
57 j += 1;
58 }
59 }
60}
61
/// Maps a possibly-negative axis index onto `0..ndim`, counting negatives from the end.
///
/// Non-negative `axis` values are returned unchanged. A negative `axis` is offset
/// by `ndim`. If `axis < -ndim`, the negative intermediate value is cast to `usize`
/// and therefore wraps to a very large index.
#[inline]
pub(crate) fn normalize_negative_axis(axis: isize, ndim: usize) -> usize {
    let adjusted = if axis.is_negative() {
        axis + ndim as isize
    } else {
        axis
    };
    adjusted as usize
}
70
71#[inline]
72pub(crate) fn normalize_negative_axes<T: Float>(axes: &NdArrayView<T>, ndim: usize) -> Vec<usize> {
73 let mut axes_ret: Vec<usize> = Vec::with_capacity(axes.len());
74 for &axis in axes.iter() {
75 let axis = if axis < T::zero() {
76 (T::from(ndim).unwrap_or_else(|| T::zero()) + axis)
77 .to_usize()
78 .unwrap_or(0)
79 } else {
80 axis.to_usize().unwrap_or(0)
81 };
82 axes_ret.push(axis);
83 }
84 axes_ret
85}
86
87#[inline]
88pub(crate) fn sparse_to_dense<T: Float>(arr: &NdArrayView<T>) -> Vec<usize> {
89 let mut axes: Vec<usize> = vec![];
90 for (i, &a) in arr.iter().enumerate() {
91 if a == T::one() {
92 axes.push(i);
93 }
94 }
95 axes
96}
97
98#[allow(unused)]
99#[inline]
100pub(crate) fn is_fully_transposed(strides: &[scirs2_core::ndarray::Ixs]) -> bool {
101 let mut ret = true;
102 for w in strides.windows(2) {
103 if w[0] > w[1] {
104 ret = false;
105 break;
106 }
107 }
108 ret
109}
110
111#[inline]
113#[allow(dead_code)]
114pub fn zeros<T: Float>(shape: &[usize]) -> NdArray<T> {
115 NdArray::<T>::zeros(shape)
116}
117
118#[inline]
120#[allow(dead_code)]
121pub fn ones<T: Float>(shape: &[usize]) -> NdArray<T> {
122 NdArray::<T>::ones(shape)
123}
124
125#[inline]
127#[allow(dead_code)]
128pub fn constant<T: Float>(value: T, shape: &[usize]) -> NdArray<T> {
129 NdArray::<T>::from_elem(shape, value)
130}
131
132use scirs2_core::random::{Rng, RngCore, SeedableRng};
133#[derive(Clone)]
137pub struct ArrayRng<A> {
138 rng: scirs2_core::random::rngs::StdRng,
139 _phantom: std::marker::PhantomData<A>,
140}
141
142impl<A> RngCore for ArrayRng<A> {
144 fn next_u32(&mut self) -> u32 {
145 self.rng.next_u32()
146 }
147
148 fn next_u64(&mut self) -> u64 {
149 self.rng.next_u64()
150 }
151
152 fn fill_bytes(&mut self, dest: &mut [u8]) {
153 self.rng.fill_bytes(dest)
154 }
155}
156
157impl<A: Float> ArrayRng<A> {
162 pub fn new() -> Self {
164 Self::from_seed(0)
165 }
166
167 pub fn from_seed(seed: u64) -> Self {
169 let rng = scirs2_core::random::rngs::StdRng::seed_from_u64(seed);
170 Self {
171 rng,
172 _phantom: std::marker::PhantomData,
173 }
174 }
175
176 pub fn as_rng(&self) -> &scirs2_core::random::rngs::StdRng {
178 &self.rng
179 }
180
181 pub fn as_rng_mut(&mut self) -> &mut scirs2_core::random::rngs::StdRng {
183 &mut self.rng
184 }
185
186 pub fn random(&mut self, shape: &[usize]) -> NdArray<A> {
189 let len = shape.iter().product();
190 let mut data = Vec::with_capacity(len);
191 for _ in 0..len {
192 data.push(
193 A::from(self.rng.random::<f64>()).expect("Shape conversion failed - this is a bug"),
194 );
195 }
196 NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(shape), data)
197 .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"))
198 }
199
200 pub fn normal(&mut self, shape: &[usize], mean: f64, std: f64) -> NdArray<A> {
203 use scirs2_core::random::{Distribution, Normal};
204 let normal = Normal::new(mean, std)
205 .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"));
206 let len = shape.iter().product();
207 let mut data = Vec::with_capacity(len);
208 for _ in 0..len {
209 data.push(
210 A::from(normal.sample(&mut self.rng))
211 .expect("Shape conversion failed - this is a bug"),
212 );
213 }
214 NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(shape), data)
215 .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"))
216 }
217
218 pub fn uniform(&mut self, shape: &[usize], low: f64, high: f64) -> NdArray<A> {
221 use scirs2_core::random::{Distribution, Uniform};
222 let uniform = Uniform::new(low, high)
223 .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"));
224 let len = shape.iter().product();
225 let mut data = Vec::with_capacity(len);
226 for _ in 0..len {
227 data.push(
228 A::from(uniform.sample(&mut self.rng))
229 .expect("Shape conversion failed - this is a bug"),
230 );
231 }
232 NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(shape), data)
233 .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"))
234 }
235
236 pub fn glorot_uniform(&mut self, shape: &[usize]) -> NdArray<A> {
240 assert!(shape.len() >= 2, "shape must have at least 2 dimensions");
241 let fan_in = shape[shape.len() - 2];
242 let fan_out = shape[shape.len() - 1];
243 let scale = (6.0 / (fan_in + fan_out) as f64).sqrt();
244 self.uniform(shape, -scale, scale)
245 }
246
247 pub fn glorot_normal(&mut self, shape: &[usize]) -> NdArray<A> {
251 assert!(shape.len() >= 2, "shape must have at least 2 dimensions");
252 let fan_in = shape[shape.len() - 2];
253 let fan_out = shape[shape.len() - 1];
254 let scale = (2.0 / (fan_in + fan_out) as f64).sqrt();
255 self.normal(shape, 0.0, scale)
256 }
257
258 pub fn he_uniform(&mut self, shape: &[usize]) -> NdArray<A> {
262 assert!(shape.len() >= 2, "shape must have at least 2 dimensions");
263 let fan_in = shape[shape.len() - 2];
264 let scale = (6.0 / fan_in as f64).sqrt();
265 self.uniform(shape, -scale, scale)
266 }
267
268 pub fn he_normal(&mut self, shape: &[usize]) -> NdArray<A> {
272 assert!(shape.len() >= 2, "shape must have at least 2 dimensions");
273 let fan_in = shape[shape.len() - 2];
274 let scale = (2.0 / fan_in as f64).sqrt();
275 self.normal(shape, 0.0, scale)
276 }
277
278 pub fn standard_normal(&mut self, shape: &[usize]) -> NdArray<A> {
280 self.normal(shape, 0.0, 1.0)
281 }
282
283 pub fn standard_uniform(&mut self, shape: &[usize]) -> NdArray<A> {
285 self.uniform(shape, 0.0, 1.0)
286 }
287
288 pub fn bernoulli(&mut self, shape: &[usize], p: f64) -> NdArray<A> {
290 use scirs2_core::random::{Bernoulli, Distribution};
291 let bernoulli =
292 Bernoulli::new(p).unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"));
293 let len = shape.iter().product();
294 let mut data = Vec::with_capacity(len);
295 for _ in 0..len {
296 let val = if bernoulli.sample(&mut self.rng) {
297 A::one()
298 } else {
299 A::zero()
300 };
301 data.push(val);
302 }
303 NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(shape), data)
304 .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"))
305 }
306
307 pub fn exponential(&mut self, shape: &[usize], lambda: f64) -> NdArray<A> {
309 use scirs2_core::random::{Distribution, Exp};
310 let exp =
311 Exp::new(lambda).unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"));
312 let len = shape.iter().product();
313 let mut data = Vec::with_capacity(len);
314 for _ in 0..len {
315 data.push(
316 A::from(exp.sample(&mut self.rng))
317 .expect("Shape conversion failed - this is a bug"),
318 );
319 }
320 NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(shape), data)
321 .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"))
322 }
323
324 pub fn log_normal(&mut self, shape: &[usize], mean: f64, stddev: f64) -> NdArray<A> {
326 use scirs2_core::random::{Distribution, LogNormal};
327 let log_normal = LogNormal::new(mean, stddev)
328 .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"));
329 let len = shape.iter().product();
330 let mut data = Vec::with_capacity(len);
331 for _ in 0..len {
332 data.push(
333 A::from(log_normal.sample(&mut self.rng))
334 .expect("Shape conversion failed - this is a bug"),
335 );
336 }
337 NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(shape), data)
338 .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"))
339 }
340
341 pub fn gamma(&mut self, shape: &[usize], shape_param: f64, scale: f64) -> NdArray<A> {
343 use scirs2_core::random::{Distribution, Gamma};
344 let gamma = Gamma::new(shape_param, scale)
345 .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"));
346 let len = shape.iter().product();
347 let mut data = Vec::with_capacity(len);
348 for _ in 0..len {
349 data.push(
350 A::from(gamma.sample(&mut self.rng))
351 .expect("Shape conversion failed - this is a bug"),
352 );
353 }
354 NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(shape), data)
355 .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"))
356 }
357}
358
359impl<A: Float> Default for ArrayRng<A> {
360 fn default() -> Self {
361 Self::new()
362 }
363}
364
/// Returns `true` for shapes that describe a single value: `[]` or `[1]`.
#[inline]
#[allow(dead_code)]
pub fn is_scalarshape(shape: &[usize]) -> bool {
    matches!(shape, [] | [1])
}
371
/// Returns the zero-dimensional shape (an empty dimension list).
#[inline]
#[allow(dead_code)]
pub fn scalarshape() -> Vec<usize> {
    Vec::new()
}
378
379#[inline]
381#[allow(dead_code)]
382pub fn from_scalar<T: Float>(value: T) -> NdArray<T> {
383 NdArray::<T>::from_elem(scirs2_core::ndarray::IxDyn(&[1]), value)
384}
385
386#[inline]
388#[allow(dead_code)]
389pub fn shape_of_view<T>(view: &NdArrayView<'_, T>) -> Vec<usize> {
390 view.shape().to_vec()
391}
392
393#[inline]
395#[allow(dead_code)]
396pub fn shape_of<T>(array: &NdArray<T>) -> Vec<usize> {
397 array.shape().to_vec()
398}
399
400#[inline]
402#[allow(dead_code)]
403pub fn get_default_rng<A: Float>() -> ArrayRng<A> {
404 ArrayRng::<A>::default()
405}
406
/// Returns an owned copy of the elements seen through `array`.
///
/// Delegates to ndarray's inherent `ArrayBase::to_owned`, so the result does not
/// borrow from the source view.
#[inline]
#[allow(dead_code)]
pub fn deep_copy<T: Float + Clone>(array: &NdArrayView<'_, T>) -> NdArray<T> {
    array.to_owned()
}
413
414#[inline]
416#[allow(dead_code)]
417pub fn select<T: Float + Clone>(
418 array: &NdArrayView<'_, T>,
419 axis: scirs2_core::ndarray::Axis,
420 indices: &[usize],
421) -> NdArray<T> {
422 let mut shape = array.shape().to_vec();
423 shape[axis.index()] = indices.len();
424
425 let mut result = NdArray::<T>::zeros(scirs2_core::ndarray::IxDyn(&shape));
426
427 for (i, &idx) in indices.iter().enumerate() {
428 let slice = array.index_axis(axis, idx);
429 result.index_axis_mut(axis, i).assign(&slice);
430 }
431
432 result
433}
434
/// Checks NumPy-style broadcast compatibility of two shapes.
///
/// Dimensions are compared from the right. Each aligned pair must be equal, or one
/// of the two must be `1`. Extra leading dimensions of the longer shape are always
/// compatible.
#[inline]
#[allow(dead_code)]
pub fn are_broadcast_compatible(shape1: &[usize], shape2: &[usize]) -> bool {
    shape1
        .iter()
        .rev()
        .zip(shape2.iter().rev())
        .all(|(&a, &b)| a == b || a == 1 || b == 1)
}
452
/// Computes the NumPy-style broadcast shape of `shape1` and `shape2`.
///
/// Shapes are aligned on their trailing dimensions, and a missing leading dimension
/// behaves like `1`. Two dimensions are compatible when they are equal or one of
/// them is `1`; the result takes the *other* dimension. So a `0` paired with a `1`
/// broadcasts to `0` (an empty axis), not `1`.
///
/// Returns `None` when some aligned pair is incompatible, which is the same
/// condition under which `are_broadcast_compatible` returns `false`.
#[inline]
#[allow(dead_code)]
pub fn broadcastshape(shape1: &[usize], shape2: &[usize]) -> Option<Vec<usize>> {
    /// Dimension `i` counted from the right, treating missing leading dims as 1.
    fn dim_from_right(shape: &[usize], i: usize) -> usize {
        shape.len().checked_sub(i + 1).map_or(1, |k| shape[k])
    }

    let rank = shape1.len().max(shape2.len());
    let mut result = vec![0; rank];
    for i in 0..rank {
        let d1 = dim_from_right(shape1, i);
        let d2 = dim_from_right(shape2, i);
        let merged = if d1 == d2 || d2 == 1 {
            d1
        } else if d1 == 1 {
            d2
        } else {
            return None;
        };
        result[rank - 1 - i] = merged;
    }
    Some(result)
}
475
476pub mod array_gen {
478 use super::*;
479
480 #[inline]
482 pub fn zeros<T: Float>(shape: &[usize]) -> NdArray<T> {
483 NdArray::<T>::zeros(shape)
484 }
485
486 #[inline]
488 pub fn ones<T: Float>(shape: &[usize]) -> NdArray<T> {
489 NdArray::<T>::ones(shape)
490 }
491
492 #[inline]
494 pub fn eye<T: Float>(n: usize) -> NdArray<T> {
495 let mut result = NdArray::<T>::zeros(scirs2_core::ndarray::IxDyn(&[n, n]));
496 for i in 0..n {
497 result[[i, i]] = T::one();
498 }
499 result
500 }
501
502 #[inline]
504 pub fn constant<T: Float>(value: T, shape: &[usize]) -> NdArray<T> {
505 NdArray::<T>::from_elem(shape, value)
506 }
507
508 pub fn random<T: Float>(shape: &[usize]) -> NdArray<T> {
510 let mut rng = ArrayRng::<T>::default();
511 rng.random(shape)
512 }
513
514 pub fn randn<T: Float>(shape: &[usize]) -> NdArray<T> {
516 let mut rng = ArrayRng::<T>::default();
517 rng.normal(shape, 0.0, 1.0)
518 }
519
520 pub fn glorot_uniform<T: Float>(shape: &[usize]) -> NdArray<T> {
522 let mut rng = ArrayRng::<T>::default();
523 rng.glorot_uniform(shape)
524 }
525
526 pub fn glorot_normal<T: Float>(shape: &[usize]) -> NdArray<T> {
528 let mut rng = ArrayRng::<T>::default();
529 rng.glorot_normal(shape)
530 }
531
532 pub fn he_uniform<T: Float>(shape: &[usize]) -> NdArray<T> {
534 let mut rng = ArrayRng::<T>::default();
535 rng.he_uniform(shape)
536 }
537
538 pub fn he_normal<T: Float>(shape: &[usize]) -> NdArray<T> {
540 let mut rng = ArrayRng::<T>::default();
541 rng.he_normal(shape)
542 }
543
544 pub fn linspace<T: Float>(start: T, end: T, num: usize) -> NdArray<T> {
546 if num <= 1 {
547 return if num == 0 {
548 NdArray::<T>::zeros(scirs2_core::ndarray::IxDyn(&[0]))
549 } else {
550 NdArray::<T>::from_elem(scirs2_core::ndarray::IxDyn(&[1]), start)
551 };
552 }
553
554 let step = (end - start) / T::from(num - 1).unwrap_or_else(|| T::one());
555 let mut data = Vec::with_capacity(num);
556
557 for i in 0..num {
558 data.push(start + step * T::from(i).unwrap_or_else(|| T::zero()));
559 }
560
561 NdArray::<T>::from_shape_vec(scirs2_core::ndarray::IxDyn(&[num]), data)
562 .expect("Shape conversion failed - this is a bug")
563 }
564
565 pub fn arange<T: Float>(start: T, end: T, step: T) -> NdArray<T> {
567 let size = ((end - start) / step).to_f64().unwrap_or(0.0).ceil() as usize;
568 let mut data = Vec::with_capacity(size);
569
570 let mut current = start;
571 while current < end {
572 data.push(current);
573 current += step;
574 }
575
576 NdArray::<T>::from_shape_vec(scirs2_core::ndarray::IxDyn(&[data.len()]), data)
577 .expect("Shape conversion failed - this is a bug")
578 }
579}