scirs2_core/ndarray/
mod.rs

1//! Complete ndarray re-export for SciRS2 ecosystem
2//!
3//! This module provides a single, unified access point for ALL ndarray functionality,
4//! ensuring SciRS2 POLICY compliance across the entire ecosystem.
5//!
6//! ## Design Philosophy
7//!
8//! 1. **Complete Feature Parity**: All ndarray functionality available through scirs2-core
9//! 2. **Zero Breaking Changes**: Existing ndarray_ext continues to work
10//! 3. **Policy Compliance**: No need for direct ndarray imports anywhere
11//! 4. **Single Source of Truth**: One place for all array operations
12//!
13//! ## Usage
14//!
15//! ```rust
16//! // Instead of:
17//! use ndarray::{Array, array, s, Axis};  // ❌ POLICY violation
18//!
19//! // Use:
20//! use scirs2_core::ndarray::*;  // ✅ POLICY compliant
21//!
22//! let arr = array![[1, 2], [3, 4]];
23//! let slice = arr.slice(s![.., 0]);
24//! ```
25
26// ========================================
27// COMPLETE NDARRAY RE-EXPORT
28// ========================================
29
30// Complete ndarray 0.17 re-export (no version switching needed anymore)
31// Use ::ndarray to refer to the external crate (not this module)
32pub use ::ndarray::*;
33
34// Note: All macros (array!, s!, azip!, etc.) are already included via `pub use ::ndarray::*;`
35
36// ========================================
37// NDARRAY-RELATED CRATE RE-EXPORTS
38// ========================================
39
40#[cfg(feature = "random")]
41pub use ndarray_rand::{rand_distr as distributions, RandomExt, SamplingStrategy};
42
43// Note: ndarray_rand is compatible with both ndarray 0.16 and 0.17
44
45// NOTE: ndarray_linalg removed - using OxiBLAS via scirs2_core::linalg module
46
47#[cfg(feature = "array_stats")]
48pub use ndarray_stats::{
49    errors as stats_errors, interpolate, CorrelationExt, DeviationExt, MaybeNan, QuantileExt,
50    Sort1dExt, SummaryStatisticsExt,
51};
52
53#[cfg(feature = "array_io")]
54pub use ndarray_npy::{
55    NpzReader, NpzWriter, ReadNpyExt, ReadNpzError, ViewMutNpyExt, ViewNpyExt, WriteNpyError,
56    WriteNpyExt,
57};
58
59// ========================================
60// ENHANCED FUNCTIONALITY
61// ========================================
62
63/// Additional utilities for SciRS2 ecosystem
64pub mod utils {
65    use super::*;
66
67    /// Create an identity matrix
68    pub fn eye<A>(n: usize) -> Array2<A>
69    where
70        A: Clone + num_traits::Zero + num_traits::One,
71    {
72        let mut arr = Array2::zeros((n, n));
73        for i in 0..n {
74            arr[[i, i]] = A::one();
75        }
76        arr
77    }
78
79    /// Create a diagonal matrix from a vector
80    pub fn diag<A>(v: &Array1<A>) -> Array2<A>
81    where
82        A: Clone + num_traits::Zero,
83    {
84        let n = v.len();
85        let mut arr = Array2::zeros((n, n));
86        for i in 0..n {
87            arr[[i, i]] = v[i].clone();
88        }
89        arr
90    }
91
92    /// Check if arrays are approximately equal
93    pub fn allclose<A, D>(
94        a: &ArrayBase<impl Data<Elem = A>, D>,
95        b: &ArrayBase<impl Data<Elem = A>, D>,
96        rtol: A,
97        atol: A,
98    ) -> bool
99    where
100        A: PartialOrd
101            + std::ops::Sub<Output = A>
102            + std::ops::Mul<Output = A>
103            + std::ops::Add<Output = A>
104            + Clone,
105        D: Dimension,
106    {
107        if a.shape() != b.shape() {
108            return false;
109        }
110
111        a.iter().zip(b.iter()).all(|(a_val, b_val)| {
112            let diff = if a_val > b_val {
113                a_val.clone() - b_val.clone()
114            } else {
115                b_val.clone() - a_val.clone()
116            };
117
118            let threshold = atol.clone()
119                + rtol.clone()
120                    * (if a_val > b_val {
121                        a_val.clone()
122                    } else {
123                        b_val.clone()
124                    });
125
126            diff <= threshold
127        })
128    }
129
130    /// Concatenate arrays along an axis
131    pub fn concatenate<A, D>(
132        axis: Axis,
133        arrays: &[ArrayView<A, D>],
134    ) -> Result<Array<A, D>, ShapeError>
135    where
136        A: Clone,
137        D: Dimension + RemoveAxis,
138    {
139        ndarray::concatenate(axis, arrays)
140    }
141
142    /// Stack arrays along a new axis
143    pub fn stack<A, D>(
144        axis: Axis,
145        arrays: &[ArrayView<A, D>],
146    ) -> Result<Array<A, D::Larger>, ShapeError>
147    where
148        A: Clone,
149        D: Dimension,
150        D::Larger: RemoveAxis,
151    {
152        ndarray::stack(axis, arrays)
153    }
154}
155
156// ========================================
157// COMPATIBILITY LAYER
158// ========================================
159
160/// Compatibility module for smooth migration from fragmented imports
161/// and ndarray version changes (SciRS2 POLICY compliance)
162pub mod compat {
163    pub use super::*;
164    use crate::numeric::{Float, FromPrimitive};
165
    /// Alias for commonly used types to match existing usage patterns
    ///
    /// `DynArray<T>` is an alias for `ArrayD<T>`.
    pub type DynArray<T> = ArrayD<T>;
    /// Alias for `Array2<T>`.
    pub type Matrix<T> = Array2<T>;
    /// Alias for `Array1<T>`.
    pub type Vector<T> = Array1<T>;
    /// Alias for `Array3<T>`.
    pub type Tensor3<T> = Array3<T>;
    /// Alias for `Array4<T>`.
    pub type Tensor4<T> = Array4<T>;
172
    /// Compatibility extensions for ndarray statistical operations
    ///
    /// This trait provides stable statistical operation APIs that remain consistent
    /// across ndarray version updates, implementing the SciRS2 POLICY principle
    /// of isolating external dependency changes to scirs2-core only.
    ///
    /// ## Rationale
    ///
    /// ndarray's statistical methods do not share a single failure convention.
    /// As used by the implementation for `ArrayBase` in this module:
    /// - `.mean()` returns `Option<T>` (in both ndarray 0.16 and 0.17)
    /// - `.var(ddof)` and `.std(ddof)` return `T` directly (may be NaN for invalid inputs)
    ///
    /// This trait hides that difference behind a uniform "value or default" API.
    ///
    /// ## Example
    ///
    /// ```rust
    /// use scirs2_core::ndarray::{Array1, compat::ArrayStatCompat};
    ///
    /// let data = Array1::from(vec![1.0, 2.0, 3.0]);
    /// let mean = data.mean_or(0.0);  // Stable API across ndarray versions
    /// ```
    pub trait ArrayStatCompat<T> {
        /// Compute the mean of the array, returning `default` if no mean is available
        ///
        /// The `ArrayBase` implementation in this module unwraps the `Option`
        /// returned by `.mean()` and substitutes `default` for `None`. A mean
        /// that is itself NaN is returned as-is; it is not replaced by `default`.
        fn mean_or(&self, default: T) -> T;

        /// Compute the variance, returning `default` if the result is NaN
        ///
        /// `ddof` is forwarded unchanged to the underlying `.var(ddof)` call.
        /// NOTE(review): the underlying `.var` may panic for an out-of-range
        /// `ddof` rather than return NaN - confirm against the ndarray docs.
        fn var_or(&self, ddof: T, default: T) -> T;

        /// Compute the standard deviation, returning `default` if the result is NaN
        ///
        /// `ddof` is forwarded unchanged to the underlying `.std(ddof)` call.
        /// NOTE(review): the underlying `.std` may panic for an out-of-range
        /// `ddof` rather than return NaN - confirm against the ndarray docs.
        fn std_or(&self, ddof: T, default: T) -> T;
    }
209
210    impl<T, S, D> ArrayStatCompat<T> for ArrayBase<S, D>
211    where
212        T: Float + FromPrimitive,
213        S: Data<Elem = T>,
214        D: Dimension,
215    {
216        fn mean_or(&self, default: T) -> T {
217            // ndarray returns Option<T> in both 0.16 and 0.17
218            self.mean().unwrap_or(default)
219        }
220
221        fn var_or(&self, ddof: T, default: T) -> T {
222            // ndarray returns T directly (may be NaN for invalid inputs)
223            let v = self.var(ddof);
224            if v.is_nan() {
225                default
226            } else {
227                v
228            }
229        }
230
231        fn std_or(&self, ddof: T, default: T) -> T {
232            // ndarray returns T directly (may be NaN for invalid inputs)
233            let s = self.std(ddof);
234            if s.is_nan() {
235                default
236            } else {
237                s
238            }
239        }
240    }
241
242    /// Re-export from ndarray_ext for backward compatibility
243    pub use crate::ndarray_ext::{
244        broadcast_1d_to_2d,
245        broadcast_apply,
246        fancy_index_2d,
247        // Keep existing extended functionality
248        indexing,
249        is_broadcast_compatible,
250        manipulation,
251        mask_select,
252        matrix,
253        reshape_2d,
254        split_2d,
255        stack_2d,
256        stats,
257        take_2d,
258        transpose_2d,
259        where_condition,
260    };
261}
262
263// ========================================
264// PRELUDE MODULE
265// ========================================
266
267/// Prelude module with most commonly used items
268pub mod prelude {
269    pub use super::{
270        arr1,
271        arr2,
272        // Essential macros
273        array,
274        azip,
275        // Utilities
276        concatenate,
277        s,
278        stack,
279
280        stack as stack_fn,
281        // Essential types
282        Array,
283        Array0,
284        Array1,
285        Array2,
286        Array3,
287        ArrayD,
288        ArrayView,
289        ArrayView1,
290        ArrayView2,
291        ArrayViewMut,
292
293        // Common operations
294        Axis,
295        // Essential traits
296        Dimension,
297        Ix1,
298        Ix2,
299        Ix3,
300        IxDyn,
301        ScalarOperand,
302        ShapeBuilder,
303
304        Zip,
305    };
306
307    #[cfg(feature = "random")]
308    pub use super::RandomExt;
309
    // Useful type aliases
    /// Alias for `Array2<T>`.
    pub type Matrix<T> = super::Array2<T>;
    /// Alias for `Array1<T>`.
    pub type Vector<T> = super::Array1<T>;
313}
314
315// ========================================
316// EXAMPLES MODULE
317// ========================================
318
319#[cfg(test)]
320pub mod examples {
321    //! Examples demonstrating unified ndarray access through scirs2-core
322
323    use super::*;
324
325    /// Example: Using all essential ndarray features through scirs2-core
326    ///
327    /// ```
328    /// use scirs2_core::ndarray::*;
329    ///
330    /// // Create arrays using the array! macro
331    /// let a = array![[1, 2, 3], [4, 5, 6]];
332    ///
333    /// // Use the s! macro for slicing
334    /// let row = a.slice(s![0, ..]);
335    /// let col = a.slice(s![.., 1]);
336    ///
337    /// // Use Axis for operations
338    /// let sum_axis0 = a.sum_axis(Axis(0));
339    /// let mean_axis1 = a.mean_axis(Axis(1));
340    ///
341    /// // Stack and concatenate
342    /// let b = array![[7, 8, 9], [10, 11, 12]];
343    /// let stacked = stack![Axis(0), a, b];
344    ///
345    /// // Views and iteration
346    /// for row in a.axis_iter(Axis(0)) {
347    ///     println!("Row: {:?}", row);
348    /// }
349    /// ```
350    #[test]
351    fn test_complete_functionality() {
352        // Array creation
353        let a = array![[1., 2.], [3., 4.]];
354        assert_eq!(a.shape(), &[2, 2]);
355
356        // Slicing with s! macro
357        let slice = a.slice(s![.., 0]);
358        assert_eq!(slice.len(), 2);
359
360        // Mathematical operations
361        let b = &a + &a;
362        assert_eq!(b[[0, 0]], 2.);
363
364        // Axis operations
365        let sum = a.sum_axis(Axis(0));
366        assert_eq!(sum.len(), 2);
367
368        // Broadcasting
369        let c = array![1., 2.];
370        let d = &a + &c;
371        assert_eq!(d[[0, 0]], 2.);
372    }
373}
374
375// ========================================
376// MIGRATION GUIDE
377// ========================================
378
pub mod migration_guide {
    //! # Migration Guide: From Fragmented to Unified ndarray Access
    //!
    //! ## Before (Fragmented, Policy-Violating)
    //!
    //! ```rust,ignore
    //! // Different files used different imports
    //! use scirs2_autograd::ndarray::{Array1, array};
    //! use scirs2_core::ndarray_ext::{ArrayView};
    //! use ndarray::{s, Axis};  // POLICY VIOLATION!
    //! ```
    //!
    //! ## After (Unified, Policy-Compliant)
    //!
    //! ```rust,ignore
    //! // Single, consistent import
    //! use scirs2_core::ndarray::*;
    //!
    //! // Everything works:
    //! let arr = array![[1, 2], [3, 4]];
    //! let slice = arr.slice(s![.., 0]);
    //! let view: ArrayView<_, _> = arr.view();
    //! let sum = arr.sum_axis(Axis(0));
    //! ```
    //!
    //! ## Benefits
    //!
    //! 1. **Single Import Path**: No more confusion about where to import from
    //! 2. **Complete Functionality**: All ndarray features available
    //! 3. **Policy Compliance**: No direct ndarray imports needed
    //! 4. **Future-Proof**: Centralized control over array functionality
    //!
    //! ## Quick Reference
    //!
    //! Macros are imported by their bare name (`s`, `array`), without the `!`.
    //!
    //! | Old Import | New Import |
    //! |------------|------------|
    //! | `use ndarray::{Array, array}` | `use scirs2_core::ndarray::{Array, array}` |
    //! | `use scirs2_autograd::ndarray::*` | `use scirs2_core::ndarray::*` |
    //! | `use scirs2_core::ndarray_ext::*` | `use scirs2_core::ndarray::*` |
    //! | `use ndarray::{s, Axis}` | `use scirs2_core::ndarray::{s, Axis}` |
}
420
421// Re-export compatibility traits for easy access
422pub use compat::ArrayStatCompat;
423
424#[cfg(test)]
425mod tests {
426    use super::*;
427
428    #[test]
429    fn test_array_macro_available() {
430        let arr = array![[1, 2], [3, 4]];
431        assert_eq!(arr.shape(), &[2, 2]);
432        assert_eq!(arr[[0, 0]], 1);
433    }
434
435    #[test]
436    fn test_s_macro_available() {
437        let arr = array![[1, 2, 3], [4, 5, 6]];
438        let slice = arr.slice(s![.., 1..]);
439        assert_eq!(slice.shape(), &[2, 2]);
440    }
441
442    #[test]
443    fn test_axis_operations() {
444        let arr = array![[1., 2.], [3., 4.]];
445        let sum = arr.sum_axis(Axis(0));
446        assert_eq!(sum, array![4., 6.]);
447    }
448
449    #[test]
450    fn test_views_and_iteration() {
451        let mut arr = array![[1, 2], [3, 4]];
452
453        // Test immutable view first
454        {
455            let view: ArrayView<_, _> = arr.view();
456            for val in view.iter() {
457                assert!(*val > 0);
458            }
459        }
460
461        // Test mutable view after immutable view is dropped
462        {
463            let mut view_mut: ArrayViewMut<_, _> = arr.view_mut();
464            for val in view_mut.iter_mut() {
465                *val *= 2;
466            }
467        }
468
469        assert_eq!(arr[[0, 0]], 2);
470    }
471
472    #[test]
473    fn test_concatenate_and_stack() {
474        let a = array![[1, 2], [3, 4]];
475        let b = array![[5, 6], [7, 8]];
476
477        // Concatenate along axis 0
478        let concat = concatenate(Axis(0), &[a.view(), b.view()]).expect("Operation failed");
479        assert_eq!(concat.shape(), &[4, 2]);
480
481        // Stack along new axis
482        let stacked =
483            crate::ndarray::stack(Axis(0), &[a.view(), b.view()]).expect("Operation failed");
484        assert_eq!(stacked.shape(), &[2, 2, 2]);
485    }
486
487    #[test]
488    fn test_zip_operations() {
489        let a = array![1, 2, 3];
490        let b = array![4, 5, 6];
491        let mut c = array![0, 0, 0];
492
493        azip!((a in &a, b in &b, c in &mut c) {
494            *c = a + b;
495        });
496
497        assert_eq!(c, array![5, 7, 9]);
498    }
499
500    #[test]
501    fn test_array_stat_compat() {
502        use compat::ArrayStatCompat;
503
504        // Test mean_or
505        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];
506        assert_eq!(data.mean_or(0.0), 3.0);
507
508        let empty = Array1::<f64>::from(vec![]);
509        assert_eq!(empty.mean_or(0.0), 0.0);
510
511        // Test var_or
512        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];
513        let var = data.var_or(1.0, 0.0);
514        assert!(var > 0.0);
515
516        // Test std_or
517        let std = data.std_or(1.0, 0.0);
518        assert!(std > 0.0);
519    }
520}