scirs2_core/ndarray/
mod.rs

1//! Complete ndarray re-export for SciRS2 ecosystem
2//!
3//! This module provides a single, unified access point for ALL ndarray functionality,
4//! ensuring SciRS2 POLICY compliance across the entire ecosystem.
5//!
6//! ## Design Philosophy
7//!
8//! 1. **Complete Feature Parity**: All ndarray functionality available through scirs2-core
9//! 2. **Zero Breaking Changes**: Existing ndarray_ext continues to work
10//! 3. **Policy Compliance**: No need for direct ndarray imports anywhere
11//! 4. **Single Source of Truth**: One place for all array operations
12//!
13//! ## Usage
14//!
15//! ```rust
//! // Instead of importing ndarray directly (a POLICY violation):
//! // use ndarray::{Array, array, s, Axis};
18//!
19//! // Use:
20//! use scirs2_core::ndarray::*;  // ✅ POLICY compliant
21//!
22//! let arr = array![[1, 2], [3, 4]];
23//! let slice = arr.slice(s![.., 0]);
24//! ```
25
26// ========================================
27// COMPLETE NDARRAY RE-EXPORT
28// ========================================
29
30// Complete ndarray 0.17 re-export (no version switching needed anymore)
31// Use ::ndarray to refer to the external crate (not this module)
32pub use ::ndarray::*;
33
34// Note: All macros (array!, s!, azip!, etc.) are already included via `pub use ::ndarray::*;`
35
36// ========================================
37// NDARRAY-RELATED CRATE RE-EXPORTS
38// ========================================
39
40#[cfg(feature = "random")]
41pub use ndarray_rand::{rand_distr as distributions, RandomExt, SamplingStrategy};
42
43// Note: ndarray_rand is compatible with both ndarray 0.16 and 0.17
44
45#[cfg(feature = "linalg")]
46pub use ndarray_linalg;
47
48#[cfg(feature = "array_stats")]
49pub use ndarray_stats::{
50    errors as stats_errors, interpolate, CorrelationExt, DeviationExt, MaybeNan, QuantileExt,
51    Sort1dExt, SummaryStatisticsExt,
52};
53
54#[cfg(feature = "array_io")]
55pub use ndarray_npy::{
56    NpzReader, NpzWriter, ReadNpyExt, ReadNpzError, ViewMutNpyExt, ViewNpyExt, WriteNpyError,
57    WriteNpyExt,
58};
59
60// ========================================
61// ENHANCED FUNCTIONALITY
62// ========================================
63
64/// Additional utilities for SciRS2 ecosystem
65pub mod utils {
66    use super::*;
67
68    /// Create an identity matrix
69    pub fn eye<A>(n: usize) -> Array2<A>
70    where
71        A: Clone + num_traits::Zero + num_traits::One,
72    {
73        let mut arr = Array2::zeros((n, n));
74        for i in 0..n {
75            arr[[i, i]] = A::one();
76        }
77        arr
78    }
79
80    /// Create a diagonal matrix from a vector
81    pub fn diag<A>(v: &Array1<A>) -> Array2<A>
82    where
83        A: Clone + num_traits::Zero,
84    {
85        let n = v.len();
86        let mut arr = Array2::zeros((n, n));
87        for i in 0..n {
88            arr[[i, i]] = v[i].clone();
89        }
90        arr
91    }
92
93    /// Check if arrays are approximately equal
94    pub fn allclose<A, D>(
95        a: &ArrayBase<impl Data<Elem = A>, D>,
96        b: &ArrayBase<impl Data<Elem = A>, D>,
97        rtol: A,
98        atol: A,
99    ) -> bool
100    where
101        A: PartialOrd
102            + std::ops::Sub<Output = A>
103            + std::ops::Mul<Output = A>
104            + std::ops::Add<Output = A>
105            + Clone,
106        D: Dimension,
107    {
108        if a.shape() != b.shape() {
109            return false;
110        }
111
112        a.iter().zip(b.iter()).all(|(a_val, b_val)| {
113            let diff = if a_val > b_val {
114                a_val.clone() - b_val.clone()
115            } else {
116                b_val.clone() - a_val.clone()
117            };
118
119            let threshold = atol.clone()
120                + rtol.clone()
121                    * (if a_val > b_val {
122                        a_val.clone()
123                    } else {
124                        b_val.clone()
125                    });
126
127            diff <= threshold
128        })
129    }
130
131    /// Concatenate arrays along an axis
132    pub fn concatenate<A, D>(
133        axis: Axis,
134        arrays: &[ArrayView<A, D>],
135    ) -> Result<Array<A, D>, ShapeError>
136    where
137        A: Clone,
138        D: Dimension + RemoveAxis,
139    {
140        ndarray::concatenate(axis, arrays)
141    }
142
143    /// Stack arrays along a new axis
144    pub fn stack<A, D>(
145        axis: Axis,
146        arrays: &[ArrayView<A, D>],
147    ) -> Result<Array<A, D::Larger>, ShapeError>
148    where
149        A: Clone,
150        D: Dimension,
151        D::Larger: RemoveAxis,
152    {
153        ndarray::stack(axis, arrays)
154    }
155}
156
157// ========================================
158// COMPATIBILITY LAYER
159// ========================================
160
161/// Compatibility module for smooth migration from fragmented imports
162/// and ndarray version changes (SciRS2 POLICY compliance)
163pub mod compat {
164    pub use super::*;
165    use crate::numeric::{Float, FromPrimitive};
166
167    /// Alias for commonly used types to match existing usage patterns
168    pub type DynArray<T> = ArrayD<T>;
169    pub type Matrix<T> = Array2<T>;
170    pub type Vector<T> = Array1<T>;
171    pub type Tensor3<T> = Array3<T>;
172    pub type Tensor4<T> = Array4<T>;
173
174    /// Compatibility extensions for ndarray statistical operations
175    ///
176    /// This trait provides stable statistical operation APIs that remain consistent
177    /// across ndarray version updates, implementing the SciRS2 POLICY principle
178    /// of isolating external dependency changes to scirs2-core only.
179    ///
180    /// ## Rationale
181    ///
182    /// ndarray's statistical methods have changed across versions:
183    /// - v0.16: `.mean()` returns `Option<T>`
184    /// - v0.17: `.mean()` returns `T` directly (may be NaN for invalid operations)
185    ///
186    /// This trait provides a consistent API regardless of the underlying ndarray version.
187    ///
188    /// ## Example
189    ///
190    /// ```rust
191    /// use scirs2_core::ndarray::{Array1, compat::ArrayStatCompat};
192    ///
193    /// let data = Array1::from(vec![1.0, 2.0, 3.0]);
194    /// let mean = data.mean_or(0.0);  // Stable API across ndarray versions
195    /// ```
196    pub trait ArrayStatCompat<T> {
197        /// Compute the mean of the array, returning a default value if computation fails
198        ///
199        /// This method abstracts over ndarray version differences:
200        /// - For ndarray 0.16: Unwraps the Option, using default if None
201        /// - For ndarray 0.17+: Returns the value, using default if NaN
202        fn mean_or(&self, default: T) -> T;
203
204        /// Compute the variance with optional default
205        fn var_or(&self, ddof: T, default: T) -> T;
206
207        /// Compute the standard deviation with optional default
208        fn std_or(&self, ddof: T, default: T) -> T;
209    }
210
211    impl<T, S, D> ArrayStatCompat<T> for ArrayBase<S, D>
212    where
213        T: Float + FromPrimitive,
214        S: Data<Elem = T>,
215        D: Dimension,
216    {
217        fn mean_or(&self, default: T) -> T {
218            // ndarray returns Option<T> in both 0.16 and 0.17
219            self.mean().unwrap_or(default)
220        }
221
222        fn var_or(&self, ddof: T, default: T) -> T {
223            // ndarray returns T directly (may be NaN for invalid inputs)
224            let v = self.var(ddof);
225            if v.is_nan() {
226                default
227            } else {
228                v
229            }
230        }
231
232        fn std_or(&self, ddof: T, default: T) -> T {
233            // ndarray returns T directly (may be NaN for invalid inputs)
234            let s = self.std(ddof);
235            if s.is_nan() {
236                default
237            } else {
238                s
239            }
240        }
241    }
242
243    /// Re-export from ndarray_ext for backward compatibility
244    pub use crate::ndarray_ext::{
245        broadcast_1d_to_2d,
246        broadcast_apply,
247        fancy_index_2d,
248        // Keep existing extended functionality
249        indexing,
250        is_broadcast_compatible,
251        manipulation,
252        mask_select,
253        matrix,
254        reshape_2d,
255        split_2d,
256        stack_2d,
257        stats,
258        take_2d,
259        transpose_2d,
260        where_condition,
261    };
262}
263
264// ========================================
265// PRELUDE MODULE
266// ========================================
267
268/// Prelude module with most commonly used items
269pub mod prelude {
270    pub use super::{
271        arr1,
272        arr2,
273        // Essential macros
274        array,
275        azip,
276        // Utilities
277        concatenate,
278        s,
279        stack,
280
281        stack as stack_fn,
282        // Essential types
283        Array,
284        Array0,
285        Array1,
286        Array2,
287        Array3,
288        ArrayD,
289        ArrayView,
290        ArrayView1,
291        ArrayView2,
292        ArrayViewMut,
293
294        // Common operations
295        Axis,
296        // Essential traits
297        Dimension,
298        Ix1,
299        Ix2,
300        Ix3,
301        IxDyn,
302        ScalarOperand,
303        ShapeBuilder,
304
305        Zip,
306    };
307
308    #[cfg(feature = "random")]
309    pub use super::RandomExt;
310
311    // Useful type aliases
312    pub type Matrix<T> = super::Array2<T>;
313    pub type Vector<T> = super::Array1<T>;
314}
315
316// ========================================
317// EXAMPLES MODULE
318// ========================================
319
#[cfg(test)]
pub mod examples {
    //! Examples demonstrating unified ndarray access through scirs2-core.
    //!
    //! These are written as ordinary unit tests rather than rustdoc examples:
    //! `rustdoc` does not build with `cfg(test)`, so doc examples attached to items
    //! inside this module would never be compiled or executed.

    use super::*;

    /// Core operations: creation, slicing, arithmetic, axis reductions and
    /// broadcasting, all through the scirs2-core re-exports.
    #[test]
    fn test_complete_functionality() {
        // Array creation
        let a = array![[1., 2.], [3., 4.]];
        assert_eq!(a.shape(), &[2, 2]);

        // Slicing with s! macro
        let slice = a.slice(s![.., 0]);
        assert_eq!(slice.len(), 2);

        // Mathematical operations
        let b = &a + &a;
        assert_eq!(b[[0, 0]], 2.);

        // Axis operations
        let sum = a.sum_axis(Axis(0));
        assert_eq!(sum.len(), 2);

        // Broadcasting: the 1-D `c` is broadcast across each row of `a`.
        let c = array![1., 2.];
        let d = &a + &c;
        assert_eq!(d[[0, 0]], 2.);
    }

    /// End-to-end workflow formerly shown only in an (unexecuted) doc example:
    /// macros, slicing, axis reductions, stacking and axis iteration.
    #[test]
    fn test_documented_workflow() {
        // Create arrays using the array! macro
        let a = array![[1, 2, 3], [4, 5, 6]];

        // Use the s! macro for slicing
        assert_eq!(a.slice(s![0, ..]), array![1, 2, 3]);
        assert_eq!(a.slice(s![.., 1]), array![2, 5]);

        // Use Axis for reductions
        assert_eq!(a.sum_axis(Axis(0)), array![5, 7, 9]);
        assert_eq!(a.sum_axis(Axis(1)), array![6, 15]);

        // Stack along a new leading axis
        let b = array![[7, 8, 9], [10, 11, 12]];
        let stacked = crate::ndarray::stack(Axis(0), &[a.view(), b.view()]).unwrap();
        assert_eq!(stacked.shape(), &[2, 2, 3]);
        assert_eq!(stacked.index_axis(Axis(0), 1), b);

        // Views and iteration over rows
        let row_sums: Vec<i32> = a.axis_iter(Axis(0)).map(|row| row.sum()).collect();
        assert_eq!(row_sums, vec![6, 15]);
    }
}
375
376// ========================================
377// MIGRATION GUIDE
378// ========================================
379
pub mod migration_guide {
    //! # Migration Guide: From Fragmented to Unified ndarray Access
    //!
    //! ## Before (Fragmented, Policy-Violating)
    //!
    //! ```rust,ignore
    //! // Different files used different imports
    //! use scirs2_autograd::ndarray::{Array1, array};
    //! use scirs2_core::ndarray_ext::{ArrayView};
    //! use ndarray::{s, Axis};  // POLICY VIOLATION!
    //! ```
    //!
    //! ## After (Unified, Policy-Compliant)
    //!
    //! ```rust
    //! // Single, consistent import
    //! use scirs2_core::ndarray::*;
    //!
    //! // Everything works:
    //! let arr = array![[1, 2], [3, 4]];
    //! let slice = arr.slice(s![.., 0]);
    //! let view: ArrayView<_, _> = arr.view();
    //! let sum = arr.sum_axis(Axis(0));
    //! assert_eq!(sum, array![4, 6]);
    //! ```
    //!
    //! ## Benefits
    //!
    //! 1. **Single Import Path**: No more confusion about where to import from
    //! 2. **Complete Functionality**: All ndarray features available
    //! 3. **Policy Compliance**: No direct ndarray imports needed
    //! 4. **Future-Proof**: Centralized control over array functionality
    //!
    //! ## Quick Reference
    //!
    //! Macros are imported by name without the `!` (e.g. `s`, `array`).
    //!
    //! | Old Import | New Import |
    //! |------------|------------|
    //! | `use ndarray::{Array, array}` | `use scirs2_core::ndarray::{Array, array}` |
    //! | `use scirs2_autograd::ndarray::*` | `use scirs2_core::ndarray::*` |
    //! | `use scirs2_core::ndarray_ext::*` | `use scirs2_core::ndarray::*` |
    //! | `use ndarray::{s, Axis}` | `use scirs2_core::ndarray::{s, Axis}` |
}
421
422// Re-export compatibility traits for easy access
423pub use compat::ArrayStatCompat;
424
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_array_macro_available() {
        let arr = array![[1, 2], [3, 4]];
        assert_eq!(arr.shape(), &[2, 2]);
        assert_eq!(arr[[0, 0]], 1);
    }

    #[test]
    fn test_s_macro_available() {
        let arr = array![[1, 2, 3], [4, 5, 6]];
        let slice = arr.slice(s![.., 1..]);
        assert_eq!(slice.shape(), &[2, 2]);
    }

    #[test]
    fn test_axis_operations() {
        let arr = array![[1., 2.], [3., 4.]];
        let sum = arr.sum_axis(Axis(0));
        assert_eq!(sum, array![4., 6.]);
    }

    #[test]
    fn test_views_and_iteration() {
        let mut arr = array![[1, 2], [3, 4]];

        // Test immutable view first
        {
            let view: ArrayView<_, _> = arr.view();
            for val in view.iter() {
                assert!(*val > 0);
            }
        }

        // Test mutable view after immutable view is dropped
        {
            let mut view_mut: ArrayViewMut<_, _> = arr.view_mut();
            for val in view_mut.iter_mut() {
                *val *= 2;
            }
        }

        assert_eq!(arr[[0, 0]], 2);
    }

    #[test]
    fn test_concatenate_and_stack() {
        let a = array![[1, 2], [3, 4]];
        let b = array![[5, 6], [7, 8]];

        // Concatenate along axis 0
        let concat = concatenate(Axis(0), &[a.view(), b.view()]).unwrap();
        assert_eq!(concat.shape(), &[4, 2]);

        // Stack along new axis
        let stacked = crate::ndarray::stack(Axis(0), &[a.view(), b.view()]).unwrap();
        assert_eq!(stacked.shape(), &[2, 2, 2]);
    }

    #[test]
    fn test_zip_operations() {
        let a = array![1, 2, 3];
        let b = array![4, 5, 6];
        let mut c = array![0, 0, 0];

        azip!((a in &a, b in &b, c in &mut c) {
            *c = a + b;
        });

        assert_eq!(c, array![5, 7, 9]);
    }

    #[test]
    fn test_array_stat_compat() {
        use compat::ArrayStatCompat;

        // Test mean_or
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(data.mean_or(0.0), 3.0);

        let empty = Array1::<f64>::from(vec![]);
        assert_eq!(empty.mean_or(0.0), 0.0);

        // Test var_or
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let var = data.var_or(1.0, 0.0);
        assert!(var > 0.0);

        // Test std_or
        let std = data.std_or(1.0, 0.0);
        assert!(std > 0.0);
    }

    /// NaN results and the empty array (ddof = 0) fall back to the default.
    #[test]
    fn test_array_stat_compat_nan_and_empty() {
        use compat::ArrayStatCompat;

        let with_nan = array![1.0, f64::NAN, 3.0];
        assert_eq!(with_nan.var_or(0.0, -1.0), -1.0);
        assert_eq!(with_nan.std_or(0.0, -1.0), -1.0);

        let empty = Array1::<f64>::from(vec![]);
        assert_eq!(empty.var_or(0.0, 7.0), 7.0);
        assert_eq!(empty.std_or(0.0, 7.0), 7.0);
    }

    #[test]
    fn test_utils_eye_and_diag() {
        let i3: Array2<f64> = utils::eye(3);
        assert_eq!(i3, array![[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]);

        let i0: Array2<f64> = utils::eye(0);
        assert_eq!(i0.shape(), &[0, 0]);

        let d = utils::diag(&array![1, 2, 3]);
        assert_eq!(d, array![[1, 0, 0], [0, 2, 0], [0, 0, 3]]);
    }

    #[test]
    fn test_utils_allclose() {
        let a = array![1.0, 2.0, 3.0];

        assert!(utils::allclose(&a, &array![1.0, 2.0 + 1e-9, 3.0], 1e-6, 1e-8));
        assert!(!utils::allclose(&a, &array![1.0, 2.1, 3.0], 1e-6, 1e-8));

        // Differing shapes are never close.
        assert!(!utils::allclose(&a, &array![1.0, 2.0], 1e-6, 1e-8));

        // NaN is not close, not even to itself.
        let nan = array![f64::NAN];
        assert!(!utils::allclose(&nan, &nan, 1e-6, 1e-8));
    }
}