Skip to main content

scirs2_autograd/
ndarray_ext.rs

1//! A small extension of [ndarray](https://github.com/rust-ndarray/ndarray)
2//!
3//! Mainly provides `array_gen`, which is a collection of array generator functions.
4use crate::error::OpResult;
5use crate::error_helpers::try_from_numeric;
6use crate::ndarray;
7
8use crate::Float;
9
10/// alias for `scirs2_core::ndarray::Array<T, IxDyn>`
11pub type NdArray<T> = scirs2_core::ndarray::Array<T, scirs2_core::ndarray::IxDyn>;
12
13/// alias for `scirs2_core::ndarray::ArrayView<T, IxDyn>`
14pub type NdArrayView<'a, T> = scirs2_core::ndarray::ArrayView<'a, T, scirs2_core::ndarray::IxDyn>;
15
16/// alias for `scirs2_core::ndarray::RawArrayView<T, IxDyn>`
17pub type RawNdArrayView<T> = scirs2_core::ndarray::RawArrayView<T, scirs2_core::ndarray::IxDyn>;
18
19/// alias for `scirs2_core::ndarray::RawArrayViewMut<T, IxDyn>`
20pub type RawNdArrayViewMut<T> =
21    scirs2_core::ndarray::RawArrayViewMut<T, scirs2_core::ndarray::IxDyn>;
22
23/// alias for `scirs2_core::ndarray::ArrayViewMut<T, IxDyn>`
24pub type NdArrayViewMut<'a, T> =
25    scirs2_core::ndarray::ArrayViewMut<'a, T, scirs2_core::ndarray::IxDyn>;
26
27#[inline]
28/// This works well only for small arrays
29pub(crate) fn asshape<T: Float>(x: &NdArrayView<T>) -> Vec<usize> {
30    x.iter().map(|a| a.to_usize().unwrap_or(0)).collect()
31}
32
33#[inline]
34pub(crate) fn expand_dims<T: Float>(x: NdArray<T>, axis: usize) -> NdArray<T> {
35    let mut shape = x.shape().to_vec();
36    shape.insert(axis, 1);
37    x.into_shape_with_order(shape)
38        .expect("Shape conversion failed - this is a bug")
39}
40
41#[inline]
42pub(crate) fn roll_axis<T: Float>(
43    arg: &mut NdArray<T>,
44    to: scirs2_core::ndarray::Axis,
45    from: scirs2_core::ndarray::Axis,
46) {
47    let i = to.index();
48    let mut j = from.index();
49    if j > i {
50        while i != j {
51            arg.swap_axes(i, j);
52            j -= 1;
53        }
54    } else {
55        while i != j {
56            arg.swap_axes(i, j);
57            j += 1;
58        }
59    }
60}
61
/// Converts a possibly-negative axis index into a non-negative one.
///
/// Negative values count from the end, so `-1` becomes `ndim - 1`.
/// NOTE(review): no range check is done; an `axis` below `-ndim` wraps around
/// through the final `as usize` cast.
#[inline]
pub(crate) fn normalize_negative_axis(axis: isize, ndim: usize) -> usize {
    match axis {
        non_negative if non_negative >= 0 => non_negative as usize,
        negative => (negative + ndim as isize) as usize,
    }
}
70
71#[inline]
72pub(crate) fn normalize_negative_axes<T: Float>(axes: &NdArrayView<T>, ndim: usize) -> Vec<usize> {
73    let mut axes_ret: Vec<usize> = Vec::with_capacity(axes.len());
74    for &axis in axes.iter() {
75        let axis = if axis < T::zero() {
76            (T::from(ndim).unwrap_or_else(|| T::zero()) + axis)
77                .to_usize()
78                .unwrap_or(0)
79        } else {
80            axis.to_usize().unwrap_or(0)
81        };
82        axes_ret.push(axis);
83    }
84    axes_ret
85}
86
87#[inline]
88pub(crate) fn sparse_to_dense<T: Float>(arr: &NdArrayView<T>) -> Vec<usize> {
89    let mut axes: Vec<usize> = vec![];
90    for (i, &a) in arr.iter().enumerate() {
91        if a == T::one() {
92            axes.push(i);
93        }
94    }
95    axes
96}
97
98#[allow(unused)]
99#[inline]
100pub(crate) fn is_fully_transposed(strides: &[scirs2_core::ndarray::Ixs]) -> bool {
101    let mut ret = true;
102    for w in strides.windows(2) {
103        if w[0] > w[1] {
104            ret = false;
105            break;
106        }
107    }
108    ret
109}
110
111/// Creates a zero array in the specified shape.
112#[inline]
113#[allow(dead_code)]
114pub fn zeros<T: Float>(shape: &[usize]) -> NdArray<T> {
115    NdArray::<T>::zeros(shape)
116}
117
118/// Creates a one array in the specified shape.
119#[inline]
120#[allow(dead_code)]
121pub fn ones<T: Float>(shape: &[usize]) -> NdArray<T> {
122    NdArray::<T>::ones(shape)
123}
124
125/// Creates a constant array in the specified shape.
126#[inline]
127#[allow(dead_code)]
128pub fn constant<T: Float>(value: T, shape: &[usize]) -> NdArray<T> {
129    NdArray::<T>::from_elem(shape, value)
130}
131
132use scirs2_core::random::{Rng, RngCore, SeedableRng};
133// In rand 0.9.0, RngCore doesn't need rand_core imported directly
134
135/// Random number generator for ndarray
136#[derive(Clone)]
137pub struct ArrayRng<A> {
138    rng: scirs2_core::random::rngs::StdRng,
139    _phantom: std::marker::PhantomData<A>,
140}
141
142// Implement RngCore for ArrayRng by delegating to the internal StdRng
143impl<A> RngCore for ArrayRng<A> {
144    fn next_u32(&mut self) -> u32 {
145        self.rng.next_u32()
146    }
147
148    fn next_u64(&mut self) -> u64 {
149        self.rng.next_u64()
150    }
151
152    fn fill_bytes(&mut self, dest: &mut [u8]) {
153        self.rng.fill_bytes(dest)
154    }
155}
156
157// Don't implement Rng directly since there's a blanket impl in rand crate
158// This was causing conflict with the blanket implementation
159// impl<A> Rng for ArrayRng<A> {}
160
161impl<A: Float> ArrayRng<A> {
162    /// Creates a new random number generator with the default seed.
163    pub fn new() -> Self {
164        Self::from_seed(0)
165    }
166
167    /// Creates a new random number generator with the specified seed.
168    pub fn from_seed(seed: u64) -> Self {
169        let rng = scirs2_core::random::rngs::StdRng::seed_from_u64(seed);
170        Self {
171            rng,
172            _phantom: std::marker::PhantomData,
173        }
174    }
175
176    /// Returns a reference to the internal RNG
177    pub fn as_rng(&self) -> &scirs2_core::random::rngs::StdRng {
178        &self.rng
179    }
180
181    /// Returns a mutable reference to the internal RNG
182    pub fn as_rng_mut(&mut self) -> &mut scirs2_core::random::rngs::StdRng {
183        &mut self.rng
184    }
185
186    /// Creates a uniform random array in the specified shape.
187    /// Values are in the range [0, 1)
188    pub fn random(&mut self, shape: &[usize]) -> NdArray<A> {
189        let len = shape.iter().product();
190        let mut data = Vec::with_capacity(len);
191        for _ in 0..len {
192            data.push(
193                A::from(self.rng.random::<f64>()).expect("Shape conversion failed - this is a bug"),
194            );
195        }
196        NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(shape), data)
197            .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"))
198    }
199
200    /// Creates a normal random array in the specified shape.
201    /// Values are drawn from a normal distribution with the specified mean and standard deviation.
202    pub fn normal(&mut self, shape: &[usize], mean: f64, std: f64) -> NdArray<A> {
203        use scirs2_core::random::{Distribution, Normal};
204        let normal = Normal::new(mean, std)
205            .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"));
206        let len = shape.iter().product();
207        let mut data = Vec::with_capacity(len);
208        for _ in 0..len {
209            data.push(
210                A::from(normal.sample(&mut self.rng))
211                    .expect("Shape conversion failed - this is a bug"),
212            );
213        }
214        NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(shape), data)
215            .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"))
216    }
217
218    /// Creates a uniform random array in the specified shape.
219    /// Values are in the range [low, high).
220    pub fn uniform(&mut self, shape: &[usize], low: f64, high: f64) -> NdArray<A> {
221        use scirs2_core::random::{Distribution, Uniform};
222        let uniform = Uniform::new(low, high)
223            .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"));
224        let len = shape.iter().product();
225        let mut data = Vec::with_capacity(len);
226        for _ in 0..len {
227            data.push(
228                A::from(uniform.sample(&mut self.rng))
229                    .expect("Shape conversion failed - this is a bug"),
230            );
231        }
232        NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(shape), data)
233            .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"))
234    }
235
236    /// Creates a random array with Glorot/Xavier uniform initialization.
237    /// For a tensor with shape (in_features, out_features),
238    /// samples are drawn from Uniform(-sqrt(6/(in_features+out_features)), sqrt(6/(in_features+out_features))).
239    pub fn glorot_uniform(&mut self, shape: &[usize]) -> NdArray<A> {
240        assert!(shape.len() >= 2, "shape must have at least 2 dimensions");
241        let fan_in = shape[shape.len() - 2];
242        let fan_out = shape[shape.len() - 1];
243        let scale = (6.0 / (fan_in + fan_out) as f64).sqrt();
244        self.uniform(shape, -scale, scale)
245    }
246
247    /// Creates a random array with Glorot/Xavier normal initialization.
248    /// For a tensor with shape (in_features, out_features),
249    /// samples are drawn from Normal(0, sqrt(2/(in_features+out_features))).
250    pub fn glorot_normal(&mut self, shape: &[usize]) -> NdArray<A> {
251        assert!(shape.len() >= 2, "shape must have at least 2 dimensions");
252        let fan_in = shape[shape.len() - 2];
253        let fan_out = shape[shape.len() - 1];
254        let scale = (2.0 / (fan_in + fan_out) as f64).sqrt();
255        self.normal(shape, 0.0, scale)
256    }
257
258    /// Creates a random array with He/Kaiming uniform initialization.
259    /// For a tensor with shape (in_features, out_features),
260    /// samples are drawn from Uniform(-sqrt(6/in_features), sqrt(6/in_features)).
261    pub fn he_uniform(&mut self, shape: &[usize]) -> NdArray<A> {
262        assert!(shape.len() >= 2, "shape must have at least 2 dimensions");
263        let fan_in = shape[shape.len() - 2];
264        let scale = (6.0 / fan_in as f64).sqrt();
265        self.uniform(shape, -scale, scale)
266    }
267
268    /// Creates a random array with He/Kaiming normal initialization.
269    /// For a tensor with shape (in_features, out_features),
270    /// samples are drawn from Normal(0, sqrt(2/in_features)).
271    pub fn he_normal(&mut self, shape: &[usize]) -> NdArray<A> {
272        assert!(shape.len() >= 2, "shape must have at least 2 dimensions");
273        let fan_in = shape[shape.len() - 2];
274        let scale = (2.0 / fan_in as f64).sqrt();
275        self.normal(shape, 0.0, scale)
276    }
277
278    /// Creates a random array from the standard normal distribution.
279    pub fn standard_normal(&mut self, shape: &[usize]) -> NdArray<A> {
280        self.normal(shape, 0.0, 1.0)
281    }
282
283    /// Creates a random array from the standard uniform distribution.
284    pub fn standard_uniform(&mut self, shape: &[usize]) -> NdArray<A> {
285        self.uniform(shape, 0.0, 1.0)
286    }
287
288    /// Creates a random array from the bernoulli distribution.
289    pub fn bernoulli(&mut self, shape: &[usize], p: f64) -> NdArray<A> {
290        use scirs2_core::random::{Bernoulli, Distribution};
291        let bernoulli =
292            Bernoulli::new(p).unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"));
293        let len = shape.iter().product();
294        let mut data = Vec::with_capacity(len);
295        for _ in 0..len {
296            let val = if bernoulli.sample(&mut self.rng) {
297                A::one()
298            } else {
299                A::zero()
300            };
301            data.push(val);
302        }
303        NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(shape), data)
304            .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"))
305    }
306
307    /// Creates a random array from the exponential distribution.
308    pub fn exponential(&mut self, shape: &[usize], lambda: f64) -> NdArray<A> {
309        use scirs2_core::random::{Distribution, Exp};
310        let exp =
311            Exp::new(lambda).unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"));
312        let len = shape.iter().product();
313        let mut data = Vec::with_capacity(len);
314        for _ in 0..len {
315            data.push(
316                A::from(exp.sample(&mut self.rng))
317                    .expect("Shape conversion failed - this is a bug"),
318            );
319        }
320        NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(shape), data)
321            .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"))
322    }
323
324    /// Creates a random array from the log-normal distribution.
325    pub fn log_normal(&mut self, shape: &[usize], mean: f64, stddev: f64) -> NdArray<A> {
326        use scirs2_core::random::{Distribution, LogNormal};
327        let log_normal = LogNormal::new(mean, stddev)
328            .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"));
329        let len = shape.iter().product();
330        let mut data = Vec::with_capacity(len);
331        for _ in 0..len {
332            data.push(
333                A::from(log_normal.sample(&mut self.rng))
334                    .expect("Shape conversion failed - this is a bug"),
335            );
336        }
337        NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(shape), data)
338            .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"))
339    }
340
341    /// Creates a random array from the gamma distribution.
342    pub fn gamma(&mut self, shape: &[usize], shape_param: f64, scale: f64) -> NdArray<A> {
343        use scirs2_core::random::{Distribution, Gamma};
344        let gamma = Gamma::new(shape_param, scale)
345            .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"));
346        let len = shape.iter().product();
347        let mut data = Vec::with_capacity(len);
348        for _ in 0..len {
349            data.push(
350                A::from(gamma.sample(&mut self.rng))
351                    .expect("Shape conversion failed - this is a bug"),
352            );
353        }
354        NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(shape), data)
355            .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"))
356    }
357}
358
359impl<A: Float> Default for ArrayRng<A> {
360    fn default() -> Self {
361        Self::new()
362    }
363}
364
/// Returns `true` when `shape` denotes a scalar: either no dimensions at all
/// (`[]`) or exactly one dimension of length one (`[1]`).
#[inline]
#[allow(dead_code)]
pub fn is_scalarshape(shape: &[usize]) -> bool {
    matches!(shape, [] | [1])
}
371
/// Returns the canonical scalar shape: a shape with no dimensions.
#[inline]
#[allow(dead_code)]
pub fn scalarshape() -> Vec<usize> {
    Vec::new()
}
378
379/// Create an array from a scalar value
380#[inline]
381#[allow(dead_code)]
382pub fn from_scalar<T: Float>(value: T) -> NdArray<T> {
383    NdArray::<T>::from_elem(scirs2_core::ndarray::IxDyn(&[1]), value)
384}
385
386/// Get shape of an ndarray view
387#[inline]
388#[allow(dead_code)]
389pub fn shape_of_view<T>(view: &NdArrayView<'_, T>) -> Vec<usize> {
390    view.shape().to_vec()
391}
392
393/// Get shape of an ndarray
394#[inline]
395#[allow(dead_code)]
396pub fn shape_of<T>(array: &NdArray<T>) -> Vec<usize> {
397    array.shape().to_vec()
398}
399
400/// Get default random number generator
401#[inline]
402#[allow(dead_code)]
403pub fn get_default_rng<A: Float>() -> ArrayRng<A> {
404    ArrayRng::<A>::default()
405}
406
407/// Create a deep copy of an ndarray
408#[inline]
409#[allow(dead_code)]
410pub fn deep_copy<T: Float + Clone>(array: &NdArrayView<'_, T>) -> NdArray<T> {
411    array.to_owned()
412}
413
414/// Select elements from an array along an axis
415#[inline]
416#[allow(dead_code)]
417pub fn select<T: Float + Clone>(
418    array: &NdArrayView<'_, T>,
419    axis: scirs2_core::ndarray::Axis,
420    indices: &[usize],
421) -> NdArray<T> {
422    let mut shape = array.shape().to_vec();
423    shape[axis.index()] = indices.len();
424
425    let mut result = NdArray::<T>::zeros(scirs2_core::ndarray::IxDyn(&shape));
426
427    for (i, &idx) in indices.iter().enumerate() {
428        let slice = array.index_axis(axis, idx);
429        result.index_axis_mut(axis, i).assign(&slice);
430    }
431
432    result
433}
434
/// Returns `true` when `shape1` and `shape2` can be broadcast together.
///
/// The shapes are compared from their trailing dimensions; each aligned pair
/// must be equal or contain a `1`. Leading dimensions present in only one of
/// the shapes are always compatible.
#[inline]
#[allow(dead_code)]
pub fn are_broadcast_compatible(shape1: &[usize], shape2: &[usize]) -> bool {
    shape1
        .iter()
        .rev()
        .zip(shape2.iter().rev())
        .all(|(&lhs, &rhs)| lhs == rhs || lhs == 1 || rhs == 1)
}
452
/// Computes the shape produced by broadcasting `shape1` against `shape2`,
/// or `None` if the shapes are incompatible.
///
/// The shapes are aligned at their trailing dimensions, and a missing leading
/// dimension counts as `1`. Along each axis the two sizes must be equal or one
/// of them must be `1`; the result takes the other size. A `0`-length axis
/// broadcast against `1` therefore stays `0`, whereas taking the maximum of the
/// two sizes would wrongly give `1`.
#[inline]
#[allow(dead_code)]
pub fn broadcastshape(shape1: &[usize], shape2: &[usize]) -> Option<Vec<usize>> {
    let ndim = shape1.len().max(shape2.len());
    // Number of implicit leading `1`s each shape gets when aligned to `ndim`.
    let pad1 = ndim - shape1.len();
    let pad2 = ndim - shape2.len();

    (0..ndim)
        .map(|axis| {
            let d1 = if axis < pad1 { 1 } else { shape1[axis - pad1] };
            let d2 = if axis < pad2 { 1 } else { shape2[axis - pad2] };
            match (d1, d2) {
                _ if d1 == d2 => Some(d1),
                (1, _) => Some(d2),
                (_, 1) => Some(d1),
                _ => None,
            }
        })
        .collect()
}
475
476/// Array generation functions
477pub mod array_gen {
478    use super::*;
479
480    /// Creates a zero array in the specified shape.
481    #[inline]
482    pub fn zeros<T: Float>(shape: &[usize]) -> NdArray<T> {
483        NdArray::<T>::zeros(shape)
484    }
485
486    /// Creates a one array in the specified shape.
487    #[inline]
488    pub fn ones<T: Float>(shape: &[usize]) -> NdArray<T> {
489        NdArray::<T>::ones(shape)
490    }
491
492    /// Creates a 2D identity matrix of the specified size.
493    #[inline]
494    pub fn eye<T: Float>(n: usize) -> NdArray<T> {
495        let mut result = NdArray::<T>::zeros(scirs2_core::ndarray::IxDyn(&[n, n]));
496        for i in 0..n {
497            result[[i, i]] = T::one();
498        }
499        result
500    }
501
502    /// Creates a constant array in the specified shape.
503    #[inline]
504    pub fn constant<T: Float>(value: T, shape: &[usize]) -> NdArray<T> {
505        NdArray::<T>::from_elem(shape, value)
506    }
507
508    /// Generates a random array in the specified shape with values between 0 and 1.
509    pub fn random<T: Float>(shape: &[usize]) -> NdArray<T> {
510        let mut rng = ArrayRng::<T>::default();
511        rng.random(shape)
512    }
513
514    /// Generates a random normal array in the specified shape.
515    pub fn randn<T: Float>(shape: &[usize]) -> NdArray<T> {
516        let mut rng = ArrayRng::<T>::default();
517        rng.normal(shape, 0.0, 1.0)
518    }
519
520    /// Creates a Glorot/Xavier uniform initialized array in the specified shape.
521    pub fn glorot_uniform<T: Float>(shape: &[usize]) -> NdArray<T> {
522        let mut rng = ArrayRng::<T>::default();
523        rng.glorot_uniform(shape)
524    }
525
526    /// Creates a Glorot/Xavier normal initialized array in the specified shape.
527    pub fn glorot_normal<T: Float>(shape: &[usize]) -> NdArray<T> {
528        let mut rng = ArrayRng::<T>::default();
529        rng.glorot_normal(shape)
530    }
531
532    /// Creates a He/Kaiming uniform initialized array in the specified shape.
533    pub fn he_uniform<T: Float>(shape: &[usize]) -> NdArray<T> {
534        let mut rng = ArrayRng::<T>::default();
535        rng.he_uniform(shape)
536    }
537
538    /// Creates a He/Kaiming normal initialized array in the specified shape.
539    pub fn he_normal<T: Float>(shape: &[usize]) -> NdArray<T> {
540        let mut rng = ArrayRng::<T>::default();
541        rng.he_normal(shape)
542    }
543
544    /// Creates an array with a linearly spaced sequence from start to end.
545    pub fn linspace<T: Float>(start: T, end: T, num: usize) -> NdArray<T> {
546        if num <= 1 {
547            return if num == 0 {
548                NdArray::<T>::zeros(scirs2_core::ndarray::IxDyn(&[0]))
549            } else {
550                NdArray::<T>::from_elem(scirs2_core::ndarray::IxDyn(&[1]), start)
551            };
552        }
553
554        let step = (end - start) / T::from(num - 1).unwrap_or_else(|| T::one());
555        let mut data = Vec::with_capacity(num);
556
557        for i in 0..num {
558            data.push(start + step * T::from(i).unwrap_or_else(|| T::zero()));
559        }
560
561        NdArray::<T>::from_shape_vec(scirs2_core::ndarray::IxDyn(&[num]), data)
562            .expect("Shape conversion failed - this is a bug")
563    }
564
565    /// Creates an array of evenly spaced values within a given interval.
566    pub fn arange<T: Float>(start: T, end: T, step: T) -> NdArray<T> {
567        let size = ((end - start) / step).to_f64().unwrap_or(0.0).ceil() as usize;
568        let mut data = Vec::with_capacity(size);
569
570        let mut current = start;
571        while current < end {
572            data.push(current);
573            current += step;
574        }
575
576        NdArray::<T>::from_shape_vec(scirs2_core::ndarray::IxDyn(&[data.len()]), data)
577            .expect("Shape conversion failed - this is a bug")
578    }
579}