scirs2_core/ndarray/mod.rs
1//! Complete ndarray re-export for SciRS2 ecosystem
2//!
3//! This module provides a single, unified access point for ALL ndarray functionality,
4//! ensuring SciRS2 POLICY compliance across the entire ecosystem.
5//!
6//! ## Design Philosophy
7//!
8//! 1. **Complete Feature Parity**: All ndarray functionality available through scirs2-core
9//! 2. **Zero Breaking Changes**: Existing ndarray_ext continues to work
10//! 3. **Policy Compliance**: No need for direct ndarray imports anywhere
11//! 4. **Single Source of Truth**: One place for all array operations
12//!
13//! ## Usage
14//!
15//! ```rust
16//! // Instead of:
//! // use ndarray::{Array, array, s, Axis}; // ❌ POLICY violation
18//!
19//! // Use:
20//! use scirs2_core::ndarray::*; // ✅ POLICY compliant
21//!
22//! let arr = array![[1, 2], [3, 4]];
23//! let slice = arr.slice(s![.., 0]);
24//! ```
25
26// ========================================
27// COMPLETE NDARRAY RE-EXPORT
28// ========================================
29
30// Complete ndarray 0.17 re-export (no version switching needed anymore)
31// Use ::ndarray to refer to the external crate (not this module)
32pub use ::ndarray::*;
33
34// Note: All macros (array!, s!, azip!, etc.) are already included via `pub use ::ndarray::*;`
35
36// ========================================
37// NDARRAY-RELATED CRATE RE-EXPORTS
38// ========================================
39
40#[cfg(feature = "random")]
41pub use ndarray_rand::{rand_distr as distributions, RandomExt, SamplingStrategy};
42
43// Note: ndarray_rand is compatible with both ndarray 0.16 and 0.17
44
45// NOTE: ndarray_linalg removed - using OxiBLAS via scirs2_core::linalg module
46
47#[cfg(feature = "array_stats")]
48pub use ndarray_stats::{
49 errors as stats_errors, interpolate, CorrelationExt, DeviationExt, MaybeNan, QuantileExt,
50 Sort1dExt, SummaryStatisticsExt,
51};
52
53#[cfg(feature = "array_io")]
54pub use ndarray_npy::{
55 NpzReader, NpzWriter, ReadNpyExt, ReadNpzError, ViewMutNpyExt, ViewNpyExt, WriteNpyError,
56 WriteNpyExt,
57};
58
59// ========================================
60// ENHANCED FUNCTIONALITY
61// ========================================
62
63/// Additional utilities for SciRS2 ecosystem
64pub mod utils {
65 use super::*;
66
67 /// Create an identity matrix
68 pub fn eye<A>(n: usize) -> Array2<A>
69 where
70 A: Clone + num_traits::Zero + num_traits::One,
71 {
72 let mut arr = Array2::zeros((n, n));
73 for i in 0..n {
74 arr[[i, i]] = A::one();
75 }
76 arr
77 }
78
79 /// Create a diagonal matrix from a vector
80 pub fn diag<A>(v: &Array1<A>) -> Array2<A>
81 where
82 A: Clone + num_traits::Zero,
83 {
84 let n = v.len();
85 let mut arr = Array2::zeros((n, n));
86 for i in 0..n {
87 arr[[i, i]] = v[i].clone();
88 }
89 arr
90 }
91
92 /// Check if arrays are approximately equal
93 pub fn allclose<A, D>(
94 a: &ArrayBase<impl Data<Elem = A>, D>,
95 b: &ArrayBase<impl Data<Elem = A>, D>,
96 rtol: A,
97 atol: A,
98 ) -> bool
99 where
100 A: PartialOrd
101 + std::ops::Sub<Output = A>
102 + std::ops::Mul<Output = A>
103 + std::ops::Add<Output = A>
104 + Clone,
105 D: Dimension,
106 {
107 if a.shape() != b.shape() {
108 return false;
109 }
110
111 a.iter().zip(b.iter()).all(|(a_val, b_val)| {
112 let diff = if a_val > b_val {
113 a_val.clone() - b_val.clone()
114 } else {
115 b_val.clone() - a_val.clone()
116 };
117
118 let threshold = atol.clone()
119 + rtol.clone()
120 * (if a_val > b_val {
121 a_val.clone()
122 } else {
123 b_val.clone()
124 });
125
126 diff <= threshold
127 })
128 }
129
130 /// Concatenate arrays along an axis
131 pub fn concatenate<A, D>(
132 axis: Axis,
133 arrays: &[ArrayView<A, D>],
134 ) -> Result<Array<A, D>, ShapeError>
135 where
136 A: Clone,
137 D: Dimension + RemoveAxis,
138 {
139 ndarray::concatenate(axis, arrays)
140 }
141
142 /// Stack arrays along a new axis
143 pub fn stack<A, D>(
144 axis: Axis,
145 arrays: &[ArrayView<A, D>],
146 ) -> Result<Array<A, D::Larger>, ShapeError>
147 where
148 A: Clone,
149 D: Dimension,
150 D::Larger: RemoveAxis,
151 {
152 ndarray::stack(axis, arrays)
153 }
154}
155
156// ========================================
157// COMPATIBILITY LAYER
158// ========================================
159
160/// Compatibility module for smooth migration from fragmented imports
161/// and ndarray version changes (SciRS2 POLICY compliance)
162pub mod compat {
163 pub use super::*;
164 use crate::numeric::{Float, FromPrimitive};
165
166 /// Alias for commonly used types to match existing usage patterns
167 pub type DynArray<T> = ArrayD<T>;
168 pub type Matrix<T> = Array2<T>;
169 pub type Vector<T> = Array1<T>;
170 pub type Tensor3<T> = Array3<T>;
171 pub type Tensor4<T> = Array4<T>;
172
173 /// Compatibility extensions for ndarray statistical operations
174 ///
175 /// This trait provides stable statistical operation APIs that remain consistent
176 /// across ndarray version updates, implementing the SciRS2 POLICY principle
177 /// of isolating external dependency changes to scirs2-core only.
178 ///
179 /// ## Rationale
180 ///
181 /// ndarray's statistical methods have changed across versions:
182 /// - v0.16: `.mean()` returns `Option<T>`
183 /// - v0.17: `.mean()` returns `T` directly (may be NaN for invalid operations)
184 ///
185 /// This trait provides a consistent API regardless of the underlying ndarray version.
186 ///
187 /// ## Example
188 ///
189 /// ```rust
190 /// use scirs2_core::ndarray::{Array1, compat::ArrayStatCompat};
191 ///
192 /// let data = Array1::from(vec![1.0, 2.0, 3.0]);
193 /// let mean = data.mean_or(0.0); // Stable API across ndarray versions
194 /// ```
195 pub trait ArrayStatCompat<T> {
196 /// Compute the mean of the array, returning a default value if computation fails
197 ///
198 /// This method abstracts over ndarray version differences:
199 /// - For ndarray 0.16: Unwraps the Option, using default if None
200 /// - For ndarray 0.17+: Returns the value, using default if NaN
201 fn mean_or(&self, default: T) -> T;
202
203 /// Compute the variance with optional default
204 fn var_or(&self, ddof: T, default: T) -> T;
205
206 /// Compute the standard deviation with optional default
207 fn std_or(&self, ddof: T, default: T) -> T;
208 }
209
210 impl<T, S, D> ArrayStatCompat<T> for ArrayBase<S, D>
211 where
212 T: Float + FromPrimitive,
213 S: Data<Elem = T>,
214 D: Dimension,
215 {
216 fn mean_or(&self, default: T) -> T {
217 // ndarray returns Option<T> in both 0.16 and 0.17
218 self.mean().unwrap_or(default)
219 }
220
221 fn var_or(&self, ddof: T, default: T) -> T {
222 // ndarray returns T directly (may be NaN for invalid inputs)
223 let v = self.var(ddof);
224 if v.is_nan() {
225 default
226 } else {
227 v
228 }
229 }
230
231 fn std_or(&self, ddof: T, default: T) -> T {
232 // ndarray returns T directly (may be NaN for invalid inputs)
233 let s = self.std(ddof);
234 if s.is_nan() {
235 default
236 } else {
237 s
238 }
239 }
240 }
241
242 /// Re-export from ndarray_ext for backward compatibility
243 pub use crate::ndarray_ext::{
244 broadcast_1d_to_2d,
245 broadcast_apply,
246 fancy_index_2d,
247 // Keep existing extended functionality
248 indexing,
249 is_broadcast_compatible,
250 manipulation,
251 mask_select,
252 matrix,
253 reshape_2d,
254 split_2d,
255 stack_2d,
256 stats,
257 take_2d,
258 transpose_2d,
259 where_condition,
260 };
261}
262
263// ========================================
264// PRELUDE MODULE
265// ========================================
266
267/// Prelude module with most commonly used items
268pub mod prelude {
269 pub use super::{
270 arr1,
271 arr2,
272 // Essential macros
273 array,
274 azip,
275 // Utilities
276 concatenate,
277 s,
278 stack,
279
280 stack as stack_fn,
281 // Essential types
282 Array,
283 Array0,
284 Array1,
285 Array2,
286 Array3,
287 ArrayD,
288 ArrayView,
289 ArrayView1,
290 ArrayView2,
291 ArrayViewMut,
292
293 // Common operations
294 Axis,
295 // Essential traits
296 Dimension,
297 Ix1,
298 Ix2,
299 Ix3,
300 IxDyn,
301 ScalarOperand,
302 ShapeBuilder,
303
304 Zip,
305 };
306
307 #[cfg(feature = "random")]
308 pub use super::RandomExt;
309
310 // Useful type aliases
311 pub type Matrix<T> = super::Array2<T>;
312 pub type Vector<T> = super::Array1<T>;
313}
314
315// ========================================
316// EXAMPLES MODULE
317// ========================================
318
#[cfg(test)]
pub mod examples {
    //! Examples demonstrating unified ndarray access through scirs2-core.
    //!
    //! This module is only compiled under `cfg(test)`, and rustdoc does not
    //! collect doc examples from such items. The walkthrough is therefore
    //! written as ordinary `#[test]` functions so that it is actually built and
    //! checked.

    use super::*;

    /// Core operations reachable through `use scirs2_core::ndarray::*`:
    /// creation, slicing, element-wise arithmetic, axis reductions and
    /// broadcasting.
    #[test]
    fn test_complete_functionality() {
        // Array creation
        let a = array![[1., 2.], [3., 4.]];
        assert_eq!(a.shape(), &[2, 2]);

        // Slicing with s! macro
        let slice = a.slice(s![.., 0]);
        assert_eq!(slice.len(), 2);

        // Mathematical operations
        let b = &a + &a;
        assert_eq!(b[[0, 0]], 2.);

        // Axis operations
        let sum = a.sum_axis(Axis(0));
        assert_eq!(sum.len(), 2);

        // Broadcasting
        let c = array![1., 2.];
        let d = &a + &c;
        assert_eq!(d[[0, 0]], 2.);
    }

    /// Walkthrough of the essential features on an integer matrix: `s!`
    /// slicing, axis reductions, the `stack!` macro and axis iteration.
    #[test]
    fn test_documented_walkthrough() {
        // Create arrays using the array! macro
        let a = array![[1, 2, 3], [4, 5, 6]];

        // Use the s! macro for slicing
        let row = a.slice(s![0, ..]);
        assert_eq!(row, array![1, 2, 3]);
        let col = a.slice(s![.., 1]);
        assert_eq!(col, array![2, 5]);

        // Use Axis for reductions
        assert_eq!(a.sum_axis(Axis(0)), array![5, 7, 9]);
        assert_eq!(a.mean_axis(Axis(1)), Some(array![2, 5]));

        // Stack along a new leading axis
        let b = array![[7, 8, 9], [10, 11, 12]];
        let stacked = stack![Axis(0), a, b];
        assert_eq!(stacked.shape(), &[2, 2, 3]);

        // Views and iteration
        let row_sums: Vec<i32> = a.axis_iter(Axis(0)).map(|r| r.sum()).collect();
        assert_eq!(row_sums, vec![6, 15]);
    }
}
373}
374
375// ========================================
376// MIGRATION GUIDE
377// ========================================
378
pub mod migration_guide {
    //! # Migration Guide: From Fragmented to Unified ndarray Access
    //!
    //! ## Before (Fragmented, Policy-Violating)
    //!
    //! ```rust,ignore
    //! // Different files used different imports
    //! use scirs2_autograd::ndarray::{Array1, array};
    //! use scirs2_core::ndarray_ext::{ArrayView};
    //! use ndarray::{s, Axis}; // POLICY VIOLATION!
    //! ```
    //!
    //! ## After (Unified, Policy-Compliant)
    //!
    //! ```rust,ignore
    //! // Single, consistent import
    //! use scirs2_core::ndarray::*;
    //!
    //! // Everything works:
    //! let arr = array![[1, 2], [3, 4]];
    //! let slice = arr.slice(s![.., 0]);
    //! let view: ArrayView<_, _> = arr.view();
    //! let sum = arr.sum_axis(Axis(0));
    //! ```
    //!
    //! Macros such as `array!` and `s!` are imported by name, without the `!`
    //! (for example `use scirs2_core::ndarray::{array, s};`).
    //!
    //! ## Benefits
    //!
    //! 1. **Single Import Path**: No more confusion about where to import from
    //! 2. **Complete Functionality**: All ndarray features available
    //! 3. **Policy Compliance**: No direct ndarray imports needed
    //! 4. **Future-Proof**: Centralized control over array functionality
    //!
    //! ## Quick Reference
    //!
    //! | Old Import | New Import |
    //! |------------|------------|
    //! | `use ndarray::{Array, array}` | `use scirs2_core::ndarray::{Array, array}` |
    //! | `use scirs2_autograd::ndarray::*` | `use scirs2_core::ndarray::*` |
    //! | `use scirs2_core::ndarray_ext::*` | `use scirs2_core::ndarray::*` |
    //! | `use ndarray::{s, Axis}` | `use scirs2_core::ndarray::{s, Axis}` |
}
419}
420
421// Re-export compatibility traits for easy access
422pub use compat::ArrayStatCompat;
423
#[cfg(test)]
mod tests {
    //! Smoke tests for the re-exported ndarray surface, the `utils` helpers and
    //! the `compat` statistics trait.

    use super::*;

    #[test]
    fn test_array_macro_available() {
        let arr = array![[1, 2], [3, 4]];
        assert_eq!(arr.shape(), &[2, 2]);
        assert_eq!(arr[[0, 0]], 1);
    }

    #[test]
    fn test_s_macro_available() {
        let arr = array![[1, 2, 3], [4, 5, 6]];
        let slice = arr.slice(s![.., 1..]);
        assert_eq!(slice.shape(), &[2, 2]);
    }

    #[test]
    fn test_axis_operations() {
        let arr = array![[1., 2.], [3., 4.]];
        let sum = arr.sum_axis(Axis(0));
        assert_eq!(sum, array![4., 6.]);
    }

    #[test]
    fn test_views_and_iteration() {
        let mut arr = array![[1, 2], [3, 4]];

        // Test immutable view first
        {
            let view: ArrayView<_, _> = arr.view();
            for val in view.iter() {
                assert!(*val > 0);
            }
        }

        // Test mutable view after immutable view is dropped
        {
            let mut view_mut: ArrayViewMut<_, _> = arr.view_mut();
            for val in view_mut.iter_mut() {
                *val *= 2;
            }
        }

        assert_eq!(arr[[0, 0]], 2);
    }

    #[test]
    fn test_concatenate_and_stack() {
        let a = array![[1, 2], [3, 4]];
        let b = array![[5, 6], [7, 8]];

        // Concatenate along axis 0
        let concat = concatenate(Axis(0), &[a.view(), b.view()]).expect("Operation failed");
        assert_eq!(concat.shape(), &[4, 2]);

        // Stack along new axis
        let stacked =
            crate::ndarray::stack(Axis(0), &[a.view(), b.view()]).expect("Operation failed");
        assert_eq!(stacked.shape(), &[2, 2, 2]);
    }

    #[test]
    fn test_zip_operations() {
        let a = array![1, 2, 3];
        let b = array![4, 5, 6];
        let mut c = array![0, 0, 0];

        azip!((a in &a, b in &b, c in &mut c) {
            *c = a + b;
        });

        assert_eq!(c, array![5, 7, 9]);
    }

    #[test]
    fn test_array_stat_compat() {
        use compat::ArrayStatCompat;

        // Test mean_or
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(data.mean_or(0.0), 3.0);

        let empty = Array1::<f64>::from(vec![]);
        assert_eq!(empty.mean_or(0.0), 0.0);

        // Test var_or
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let var = data.var_or(1.0, 0.0);
        assert!(var > 0.0);

        // Test std_or
        let std = data.std_or(1.0, 0.0);
        assert!(std > 0.0);
    }

    /// Covers the `utils` helpers, which the tests above never call.
    #[test]
    fn test_utils_helpers() {
        // Identity and diagonal construction
        assert_eq!(
            utils::eye::<f64>(3),
            array![[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]
        );
        assert_eq!(utils::eye::<f64>(0).shape(), &[0, 0]);
        assert_eq!(utils::diag(&array![1., 2.]), array![[1., 0.], [0., 2.]]);

        // Approximate equality on non-negative data
        let a = array![1.0, 2.0, 3.0];
        assert!(utils::allclose(&a, &array![1.0, 2.0 + 1e-9, 3.0], 1e-6, 1e-8));
        assert!(!utils::allclose(&a, &array![1.0, 2.5, 3.0], 1e-6, 1e-8));
        assert!(!utils::allclose(&a, &array![1.0, 2.0], 1e-6, 1e-8));

        // Joining wrappers
        let x = array![[1, 2], [3, 4]];
        let y = array![[5, 6], [7, 8]];
        let joined = utils::concatenate(Axis(1), &[x.view(), y.view()]).expect("shapes match");
        assert_eq!(joined, array![[1, 2, 5, 6], [3, 4, 7, 8]]);
        let stacked = utils::stack(Axis(2), &[x.view(), y.view()]).expect("shapes match");
        assert_eq!(stacked.shape(), &[2, 2, 2]);
        assert_eq!(stacked[[0, 1, 1]], 6);
    }
}
520}