scirs2_core/ndarray/mod.rs
1//! Complete ndarray re-export for SciRS2 ecosystem
2//!
3//! This module provides a single, unified access point for ALL ndarray functionality,
4//! ensuring SciRS2 POLICY compliance across the entire ecosystem.
5//!
6//! ## Design Philosophy
7//!
8//! 1. **Complete Feature Parity**: All ndarray functionality available through scirs2-core
9//! 2. **Zero Breaking Changes**: Existing ndarray_ext continues to work
10//! 3. **Policy Compliance**: No need for direct ndarray imports anywhere
11//! 4. **Single Source of Truth**: One place for all array operations
12//!
13//! ## Usage
14//!
15//! ```rust
//! // Instead of importing ndarray directly (❌ POLICY violation):
//! // use ndarray::{Array, array, s, Axis};
18//!
19//! // Use:
20//! use scirs2_core::ndarray::*; // ✅ POLICY compliant
21//!
22//! let arr = array![[1, 2], [3, 4]];
23//! let slice = arr.slice(s![.., 0]);
24//! ```
25
26// ========================================
27// COMPLETE NDARRAY RE-EXPORT
28// ========================================
29
30// Complete ndarray 0.17 re-export (no version switching needed anymore)
31// Use ::ndarray to refer to the external crate (not this module)
32pub use ::ndarray::*;
33
34// Note: All macros (array!, s!, azip!, etc.) are already included via `pub use ::ndarray::*;`
35
36// ========================================
37// NDARRAY-RELATED CRATE RE-EXPORTS
38// ========================================
39
40#[cfg(feature = "random")]
41pub use ndarray_rand::{rand_distr as distributions, RandomExt, SamplingStrategy};
42
43// Note: ndarray_rand is compatible with both ndarray 0.16 and 0.17
44
45#[cfg(feature = "linalg")]
46pub use ndarray_linalg;
47
48#[cfg(feature = "array_stats")]
49pub use ndarray_stats::{
50 errors as stats_errors, interpolate, CorrelationExt, DeviationExt, MaybeNan, QuantileExt,
51 Sort1dExt, SummaryStatisticsExt,
52};
53
54#[cfg(feature = "array_io")]
55pub use ndarray_npy::{
56 NpzReader, NpzWriter, ReadNpyExt, ReadNpzError, ViewMutNpyExt, ViewNpyExt, WriteNpyError,
57 WriteNpyExt,
58};
59
60// ========================================
61// ENHANCED FUNCTIONALITY
62// ========================================
63
64/// Additional utilities for SciRS2 ecosystem
65pub mod utils {
66 use super::*;
67
68 /// Create an identity matrix
69 pub fn eye<A>(n: usize) -> Array2<A>
70 where
71 A: Clone + num_traits::Zero + num_traits::One,
72 {
73 let mut arr = Array2::zeros((n, n));
74 for i in 0..n {
75 arr[[i, i]] = A::one();
76 }
77 arr
78 }
79
80 /// Create a diagonal matrix from a vector
81 pub fn diag<A>(v: &Array1<A>) -> Array2<A>
82 where
83 A: Clone + num_traits::Zero,
84 {
85 let n = v.len();
86 let mut arr = Array2::zeros((n, n));
87 for i in 0..n {
88 arr[[i, i]] = v[i].clone();
89 }
90 arr
91 }
92
93 /// Check if arrays are approximately equal
94 pub fn allclose<A, D>(
95 a: &ArrayBase<impl Data<Elem = A>, D>,
96 b: &ArrayBase<impl Data<Elem = A>, D>,
97 rtol: A,
98 atol: A,
99 ) -> bool
100 where
101 A: PartialOrd
102 + std::ops::Sub<Output = A>
103 + std::ops::Mul<Output = A>
104 + std::ops::Add<Output = A>
105 + Clone,
106 D: Dimension,
107 {
108 if a.shape() != b.shape() {
109 return false;
110 }
111
112 a.iter().zip(b.iter()).all(|(a_val, b_val)| {
113 let diff = if a_val > b_val {
114 a_val.clone() - b_val.clone()
115 } else {
116 b_val.clone() - a_val.clone()
117 };
118
119 let threshold = atol.clone()
120 + rtol.clone()
121 * (if a_val > b_val {
122 a_val.clone()
123 } else {
124 b_val.clone()
125 });
126
127 diff <= threshold
128 })
129 }
130
131 /// Concatenate arrays along an axis
132 pub fn concatenate<A, D>(
133 axis: Axis,
134 arrays: &[ArrayView<A, D>],
135 ) -> Result<Array<A, D>, ShapeError>
136 where
137 A: Clone,
138 D: Dimension + RemoveAxis,
139 {
140 ndarray::concatenate(axis, arrays)
141 }
142
143 /// Stack arrays along a new axis
144 pub fn stack<A, D>(
145 axis: Axis,
146 arrays: &[ArrayView<A, D>],
147 ) -> Result<Array<A, D::Larger>, ShapeError>
148 where
149 A: Clone,
150 D: Dimension,
151 D::Larger: RemoveAxis,
152 {
153 ndarray::stack(axis, arrays)
154 }
155}
156
157// ========================================
158// COMPATIBILITY LAYER
159// ========================================
160
161/// Compatibility module for smooth migration from fragmented imports
162/// and ndarray version changes (SciRS2 POLICY compliance)
163pub mod compat {
164 pub use super::*;
165 use crate::numeric::{Float, FromPrimitive};
166
167 /// Alias for commonly used types to match existing usage patterns
168 pub type DynArray<T> = ArrayD<T>;
169 pub type Matrix<T> = Array2<T>;
170 pub type Vector<T> = Array1<T>;
171 pub type Tensor3<T> = Array3<T>;
172 pub type Tensor4<T> = Array4<T>;
173
174 /// Compatibility extensions for ndarray statistical operations
175 ///
176 /// This trait provides stable statistical operation APIs that remain consistent
177 /// across ndarray version updates, implementing the SciRS2 POLICY principle
178 /// of isolating external dependency changes to scirs2-core only.
179 ///
180 /// ## Rationale
181 ///
182 /// ndarray's statistical methods have changed across versions:
183 /// - v0.16: `.mean()` returns `Option<T>`
184 /// - v0.17: `.mean()` returns `T` directly (may be NaN for invalid operations)
185 ///
186 /// This trait provides a consistent API regardless of the underlying ndarray version.
187 ///
188 /// ## Example
189 ///
190 /// ```rust
191 /// use scirs2_core::ndarray::{Array1, compat::ArrayStatCompat};
192 ///
193 /// let data = Array1::from(vec![1.0, 2.0, 3.0]);
194 /// let mean = data.mean_or(0.0); // Stable API across ndarray versions
195 /// ```
196 pub trait ArrayStatCompat<T> {
197 /// Compute the mean of the array, returning a default value if computation fails
198 ///
199 /// This method abstracts over ndarray version differences:
200 /// - For ndarray 0.16: Unwraps the Option, using default if None
201 /// - For ndarray 0.17+: Returns the value, using default if NaN
202 fn mean_or(&self, default: T) -> T;
203
204 /// Compute the variance with optional default
205 fn var_or(&self, ddof: T, default: T) -> T;
206
207 /// Compute the standard deviation with optional default
208 fn std_or(&self, ddof: T, default: T) -> T;
209 }
210
211 impl<T, S, D> ArrayStatCompat<T> for ArrayBase<S, D>
212 where
213 T: Float + FromPrimitive,
214 S: Data<Elem = T>,
215 D: Dimension,
216 {
217 fn mean_or(&self, default: T) -> T {
218 // ndarray returns Option<T> in both 0.16 and 0.17
219 self.mean().unwrap_or(default)
220 }
221
222 fn var_or(&self, ddof: T, default: T) -> T {
223 // ndarray returns T directly (may be NaN for invalid inputs)
224 let v = self.var(ddof);
225 if v.is_nan() {
226 default
227 } else {
228 v
229 }
230 }
231
232 fn std_or(&self, ddof: T, default: T) -> T {
233 // ndarray returns T directly (may be NaN for invalid inputs)
234 let s = self.std(ddof);
235 if s.is_nan() {
236 default
237 } else {
238 s
239 }
240 }
241 }
242
243 /// Re-export from ndarray_ext for backward compatibility
244 pub use crate::ndarray_ext::{
245 broadcast_1d_to_2d,
246 broadcast_apply,
247 fancy_index_2d,
248 // Keep existing extended functionality
249 indexing,
250 is_broadcast_compatible,
251 manipulation,
252 mask_select,
253 matrix,
254 reshape_2d,
255 split_2d,
256 stack_2d,
257 stats,
258 take_2d,
259 transpose_2d,
260 where_condition,
261 };
262}
263
264// ========================================
265// PRELUDE MODULE
266// ========================================
267
268/// Prelude module with most commonly used items
269pub mod prelude {
270 pub use super::{
271 arr1,
272 arr2,
273 // Essential macros
274 array,
275 azip,
276 // Utilities
277 concatenate,
278 s,
279 stack,
280
281 stack as stack_fn,
282 // Essential types
283 Array,
284 Array0,
285 Array1,
286 Array2,
287 Array3,
288 ArrayD,
289 ArrayView,
290 ArrayView1,
291 ArrayView2,
292 ArrayViewMut,
293
294 // Common operations
295 Axis,
296 // Essential traits
297 Dimension,
298 Ix1,
299 Ix2,
300 Ix3,
301 IxDyn,
302 ScalarOperand,
303 ShapeBuilder,
304
305 Zip,
306 };
307
308 #[cfg(feature = "random")]
309 pub use super::RandomExt;
310
311 // Useful type aliases
312 pub type Matrix<T> = super::Array2<T>;
313 pub type Vector<T> = super::Array1<T>;
314}
315
316// ========================================
317// EXAMPLES MODULE
318// ========================================
319
#[cfg(test)]
pub mod examples {
    //! Examples demonstrating unified ndarray access through scirs2-core

    use super::*;

    /// Walk-through of the core ndarray features reachable via `scirs2_core::ndarray`.
    ///
    /// NOTE: this module is `#[cfg(test)]`, so rustdoc most likely never compiles
    /// the snippet below as a doctest. It is kept as an illustration.
    ///
    /// ```
    /// use scirs2_core::ndarray::*;
    ///
    /// // Create arrays using the array! macro
    /// let a = array![[1, 2, 3], [4, 5, 6]];
    ///
    /// // Use the s! macro for slicing
    /// let row = a.slice(s![0, ..]);
    /// let col = a.slice(s![.., 1]);
    ///
    /// // Use Axis for operations
    /// let sum_axis0 = a.sum_axis(Axis(0));
    /// let mean_axis1 = a.mean_axis(Axis(1));
    ///
    /// // Stack and concatenate
    /// let b = array![[7, 8, 9], [10, 11, 12]];
    /// let stacked = stack![Axis(0), a, b];
    ///
    /// // Views and iteration
    /// for row in a.axis_iter(Axis(0)) {
    ///     println!("Row: {:?}", row);
    /// }
    /// ```
    #[test]
    fn test_complete_functionality() {
        let m = array![[1., 2.], [3., 4.]];
        assert_eq!(m.shape(), &[2, 2]);

        // `s!` slicing: the first column has one entry per row.
        let first_col = m.slice(s![.., 0]);
        assert_eq!(first_col.len(), 2);

        // Element-wise arithmetic on borrowed operands.
        let doubled = &m + &m;
        assert_eq!(doubled[[0, 0]], 2.);

        // Reducing along axis 0 leaves one value per column.
        let col_sums = m.sum_axis(Axis(0));
        assert_eq!(col_sums.len(), 2);

        // A 1-D row broadcasts across every row of the 2-D array.
        let offsets = array![1., 2.];
        let shifted = &m + &offsets;
        assert_eq!(shifted[[0, 0]], 2.);
    }
}
375
376// ========================================
377// MIGRATION GUIDE
378// ========================================
379
pub mod migration_guide {
    //! # Migration Guide: From Fragmented to Unified ndarray Access
    //!
    //! This module contains no code; it exists only to carry this guide in the
    //! rendered documentation.
    //!
    //! ## Before (Fragmented, Policy-Violating)
    //!
    //! ```rust,ignore
    //! // Different files used different imports
    //! use scirs2_autograd::ndarray::{Array1, array};
    //! use scirs2_core::ndarray_ext::{ArrayView};
    //! use ndarray::{s, Axis}; // POLICY VIOLATION!
    //! ```
    //!
    //! ## After (Unified, Policy-Compliant)
    //!
    //! ```rust,ignore
    //! // Single, consistent import
    //! use scirs2_core::ndarray::*;
    //!
    //! // Everything works:
    //! let arr = array![[1, 2], [3, 4]];
    //! let slice = arr.slice(s![.., 0]);
    //! let view: ArrayView<_, _> = arr.view();
    //! let sum = arr.sum_axis(Axis(0));
    //! ```
    //!
    //! ## Benefits
    //!
    //! 1. **Single Import Path**: No more confusion about where to import from
    //! 2. **Complete Functionality**: All ndarray features available
    //! 3. **Policy Compliance**: No direct ndarray imports needed
    //! 4. **Future-Proof**: Centralized control over array functionality
    //!
    //! ## Quick Reference
    //!
    //! Macros are imported by bare name (`s`, not `s!`).
    //!
    //! | Old Import | New Import |
    //! |------------|------------|
    //! | `use ndarray::{Array, array}` | `use scirs2_core::ndarray::{Array, array}` |
    //! | `use scirs2_autograd::ndarray::*` | `use scirs2_core::ndarray::*` |
    //! | `use scirs2_core::ndarray_ext::*` | `use scirs2_core::ndarray::*` for the array API, plus `use scirs2_core::ndarray::compat::*` for the ndarray_ext helpers (e.g. `broadcast_1d_to_2d`), which the root glob does not re-export |
    //! | `use ndarray::{s, Axis}` | `use scirs2_core::ndarray::{s, Axis}` |
}
421
422// Re-export compatibility traits for easy access
423pub use compat::ArrayStatCompat;
424
#[cfg(test)]
mod tests {
    //! Smoke tests confirming that the re-exported ndarray surface and the
    //! compatibility helpers are reachable through this module.

    use super::*;

    /// `array!` is reachable and builds the expected 2×2 array.
    #[test]
    fn test_array_macro_available() {
        let grid = array![[1, 2], [3, 4]];
        assert_eq!(grid.shape(), &[2, 2]);
        assert_eq!(grid[[0, 0]], 1);
    }

    /// `s!` is reachable; dropping the first column leaves a 2×2 view.
    #[test]
    fn test_s_macro_available() {
        let grid = array![[1, 2, 3], [4, 5, 6]];
        let tail_cols = grid.slice(s![.., 1..]);
        assert_eq!(tail_cols.shape(), &[2, 2]);
    }

    /// `Axis` is reachable; summing along axis 0 yields the column sums.
    #[test]
    fn test_axis_operations() {
        let grid = array![[1., 2.], [3., 4.]];
        let col_sums = grid.sum_axis(Axis(0));
        assert_eq!(col_sums, array![4., 6.]);
    }

    /// Shared and mutable view types are reachable and usable one after another.
    #[test]
    fn test_views_and_iteration() {
        let mut grid = array![[1, 2], [3, 4]];

        {
            // Read-only pass through a shared view.
            let shared: ArrayView<_, _> = grid.view();
            assert!(shared.iter().all(|&x| x > 0));
        }

        {
            // The shared view has been dropped, so a mutable view may be taken.
            let mut exclusive: ArrayViewMut<_, _> = grid.view_mut();
            exclusive.iter_mut().for_each(|x| *x *= 2);
        }

        assert_eq!(grid[[0, 0]], 2);
    }

    /// `concatenate` joins along an existing axis; `stack` adds a new one.
    #[test]
    fn test_concatenate_and_stack() {
        let top = array![[1, 2], [3, 4]];
        let bottom = array![[5, 6], [7, 8]];
        let parts = [top.view(), bottom.view()];

        let joined = concatenate(Axis(0), &parts).unwrap();
        assert_eq!(joined.shape(), &[4, 2]);

        let layered = crate::ndarray::stack(Axis(0), &parts).unwrap();
        assert_eq!(layered.shape(), &[2, 2, 2]);
    }

    /// `azip!` is reachable and writes element-wise sums into the output.
    #[test]
    fn test_zip_operations() {
        let lhs = array![1, 2, 3];
        let rhs = array![4, 5, 6];
        let mut out = array![0, 0, 0];

        azip!((l in &lhs, r in &rhs, o in &mut out) {
            *o = l + r;
        });

        assert_eq!(out, array![5, 7, 9]);
    }

    /// The `_or` statistics return real values for data and the default otherwise.
    #[test]
    fn test_array_stat_compat() {
        use compat::ArrayStatCompat;

        let samples = array![1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(samples.mean_or(0.0), 3.0);

        let no_samples = Array1::<f64>::from(vec![]);
        assert_eq!(no_samples.mean_or(0.0), 0.0);

        // Sample variance / standard deviation (ddof = 1) of non-constant data are positive.
        assert!(samples.var_or(1.0, 0.0) > 0.0);
        assert!(samples.std_or(1.0, 0.0) > 0.0);
    }
}