Skip to main content

scirs2_core/
builders.rs

1//! # Ergonomic Builder Patterns for Array Construction
2//!
3//! This module provides fluent builder patterns that make common array-construction
4//! tasks more discoverable and IDE-friendly. Instead of remembering multiple
5//! constructors scattered across ndarray, users can use a single entry-point and
6//! let IDE autocomplete guide them.
7//!
8//! ## Design Goals
9//!
10//! - **Discoverability**: All construction paths live under `MatrixBuilder`,
11//!   `VectorBuilder`, and `ArrayBuilder` — easy to find in IDEs.
12//! - **No unwrap**: Every fallible operation returns `CoreResult`.
13//! - **Generic**: Works for any numeric type satisfying the appropriate traits.
14//! - **Zero-cost**: The builders are thin wrappers; all cost is in the actual
15//!   array allocation, matching what you would write by hand.
16//!
17//! ## Usage
18//!
19//! ```rust
20//! use scirs2_core::builders::{MatrixBuilder, VectorBuilder, ArrayBuilder};
21//!
22//! // 2D Matrix construction
23//! let eye3 = MatrixBuilder::<f64>::eye(3);
24//! let zeros = MatrixBuilder::<f64>::zeros(4, 4);
25//! let from_data = MatrixBuilder::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2)
26//!     .expect("correct element count");
27//! let from_fn = MatrixBuilder::from_fn(3, 3, |r, c| if r == c { 1.0f64 } else { 0.0 });
28//!
29//! // 1D Vector construction
30//! let linspace = VectorBuilder::<f64>::linspace(0.0, 1.0, 11);
31//! let arange = VectorBuilder::<f64>::arange(0.0, 5.0, 1.0);
32//! let logspace = VectorBuilder::<f64>::logspace(0.0, 3.0, 4);
33//! let from_vec = VectorBuilder::from_vec(vec![1.0, 2.0, 3.0]);
34//!
35//! // Generic multi-dim array
36//! let shaped = ArrayBuilder::<f64, _>::zeros(ndarray::Ix2(3, 4));
37//! ```
38
39use crate::error::{CoreError, CoreResult, ErrorContext};
40use ::ndarray::{Array1, Array2, ArrayD, Dimension, IntoDimension, IxDyn, ShapeError};
41use num_traits::{Float, One, Zero};
42use std::fmt::Display;
43use std::ops::MulAssign;
44
45// ============================================================================
46// MatrixBuilder — 2D Matrix Construction
47// ============================================================================
48
/// Entry point for constructing two-dimensional arrays (`Array2<T>`).
///
/// The type is a pure namespace: it wraps a `PhantomData<T>` and every
/// constructor is an associated function, so typing `MatrixBuilder::` in an IDE
/// lists all available construction paths without needing an instance.
///
/// # Type Parameter
///
/// `T` is the element type of the produced matrix. The trait bounds live on the
/// individual `impl` blocks, so each constructor only demands what it actually
/// needs (for example `Zero` for `zeros`, `Float` for the random constructors).
///
/// # Examples
///
/// ```rust
/// use scirs2_core::builders::MatrixBuilder;
///
/// // Identity matrix
/// let eye = MatrixBuilder::<f64>::eye(3);
/// assert_eq!(eye[[0, 0]], 1.0);
/// assert_eq!(eye[[0, 1]], 0.0);
///
/// // Zeros / ones
/// let z = MatrixBuilder::<f64>::zeros(2, 3);
/// let o = MatrixBuilder::<f64>::ones(2, 3);
///
/// // From closure — computed element-by-element
/// let computed = MatrixBuilder::from_fn(3, 3, |r, c| (r * 3 + c) as f64);
/// assert_eq!(computed[[1, 2]], 5.0);
/// ```
pub struct MatrixBuilder<T>(::core::marker::PhantomData<T>);
77
78impl<T> MatrixBuilder<T>
79where
80    T: Clone + Zero,
81{
82    /// Create a matrix of all zeros with shape `(rows, cols)`.
83    ///
84    /// ```rust
85    /// use scirs2_core::builders::MatrixBuilder;
86    ///
87    /// let m = MatrixBuilder::<f64>::zeros(3, 4);
88    /// assert_eq!(m.shape(), &[3, 4]);
89    /// assert_eq!(m[[0, 0]], 0.0);
90    /// ```
91    pub fn zeros(rows: usize, cols: usize) -> Array2<T> {
92        Array2::<T>::zeros((rows, cols))
93    }
94
95    /// Build a matrix from a flat `Vec` of elements in row-major order.
96    ///
97    /// Returns an error if the number of elements does not match `rows * cols`.
98    ///
99    /// ```rust
100    /// use scirs2_core::builders::MatrixBuilder;
101    ///
102    /// let m = MatrixBuilder::from_vec(vec![1.0f64, 2.0, 3.0, 4.0], 2, 2)
103    ///     .expect("element count matches");
104    /// assert_eq!(m[[0, 0]], 1.0);
105    /// assert_eq!(m[[1, 1]], 4.0);
106    /// ```
107    pub fn from_vec(data: Vec<T>, rows: usize, cols: usize) -> CoreResult<Array2<T>> {
108        if data.len() != rows * cols {
109            return Err(CoreError::InvalidInput(ErrorContext::new(format!(
110                "MatrixBuilder::from_vec: expected {} elements for a {}×{} matrix, got {}",
111                rows * cols,
112                rows,
113                cols,
114                data.len()
115            ))));
116        }
117        Array2::from_shape_vec((rows, cols), data).map_err(|e: ShapeError| {
118            CoreError::InvalidInput(ErrorContext::new(format!(
119                "MatrixBuilder::from_vec shape error: {e}"
120            )))
121        })
122    }
123}
124
125impl<T> MatrixBuilder<T>
126where
127    T: Clone + Zero + One,
128{
129    /// Create a square identity matrix of size `n × n`.
130    ///
131    /// ```rust
132    /// use scirs2_core::builders::MatrixBuilder;
133    ///
134    /// let eye = MatrixBuilder::<f64>::eye(3);
135    /// assert_eq!(eye[[2, 2]], 1.0);
136    /// assert_eq!(eye[[0, 1]], 0.0);
137    /// ```
138    pub fn eye(n: usize) -> Array2<T> {
139        let mut m = Array2::<T>::zeros((n, n));
140        for i in 0..n {
141            m[[i, i]] = T::one();
142        }
143        m
144    }
145
146    /// Create a matrix of all ones with shape `(rows, cols)`.
147    ///
148    /// ```rust
149    /// use scirs2_core::builders::MatrixBuilder;
150    ///
151    /// let m = MatrixBuilder::<f64>::ones(2, 3);
152    /// assert_eq!(m[[1, 2]], 1.0);
153    /// ```
154    pub fn ones(rows: usize, cols: usize) -> Array2<T> {
155        Array2::<T>::from_elem((rows, cols), T::one())
156    }
157}
158
159impl<T> MatrixBuilder<T>
160where
161    T: Clone,
162{
163    /// Create a matrix filled with a single constant value.
164    ///
165    /// ```rust
166    /// use scirs2_core::builders::MatrixBuilder;
167    ///
168    /// let m = MatrixBuilder::full(3, 3, 7_i32);
169    /// assert_eq!(m[[0, 0]], 7);
170    /// assert_eq!(m[[2, 2]], 7);
171    /// ```
172    pub fn full(rows: usize, cols: usize, value: T) -> Array2<T> {
173        Array2::from_elem((rows, cols), value)
174    }
175
176    /// Create a matrix where each element is produced by calling `f(row, col)`.
177    ///
178    /// ```rust
179    /// use scirs2_core::builders::MatrixBuilder;
180    ///
181    /// let m = MatrixBuilder::from_fn(3, 3, |r, c| (r * 3 + c) as f64);
182    /// assert_eq!(m[[0, 0]], 0.0);
183    /// assert_eq!(m[[2, 2]], 8.0);
184    /// ```
185    pub fn from_fn<F>(rows: usize, cols: usize, mut f: F) -> Array2<T>
186    where
187        F: FnMut(usize, usize) -> T,
188    {
189        Array2::from_shape_fn((rows, cols), |(r, c)| f(r, c))
190    }
191}
192
193impl<T> MatrixBuilder<T>
194where
195    T: Float + Clone,
196{
197    /// Create a matrix populated with uniform random values in `[0, 1)` using a seeded
198    /// ChaCha8 RNG for reproducibility.
199    ///
200    /// The `seed` parameter lets callers produce deterministic results in tests
201    /// and benchmarks while still getting varied values in production by passing
202    /// different seeds.
203    ///
204    /// ```rust
205    /// use scirs2_core::builders::MatrixBuilder;
206    ///
207    /// let m = MatrixBuilder::<f64>::rand(3, 3, 42);
208    /// assert_eq!(m.shape(), &[3, 3]);
209    /// // All values should be in [0, 1)
210    /// assert!(m.iter().all(|&v| v >= 0.0 && v < 1.0));
211    /// ```
212    pub fn rand(rows: usize, cols: usize, seed: u64) -> Array2<T> {
213        use rand::SeedableRng;
214        use rand_chacha::ChaCha8Rng;
215
216        let mut rng = ChaCha8Rng::seed_from_u64(seed);
217        Array2::from_shape_fn((rows, cols), |_| {
218            // Generate a uniform f64 in [0, 1) and cast to T
219            use rand::Rng;
220            let v: f64 = rng.random();
221            T::from(v).unwrap_or_else(T::zero)
222        })
223    }
224
225    /// Create a matrix populated with standard normal (`N(0, 1)`) random values.
226    ///
227    /// ```rust
228    /// use scirs2_core::builders::MatrixBuilder;
229    ///
230    /// let m = MatrixBuilder::<f64>::randn(4, 4, 0);
231    /// assert_eq!(m.shape(), &[4, 4]);
232    /// ```
233    pub fn randn(rows: usize, cols: usize, seed: u64) -> Array2<T> {
234        use rand::SeedableRng;
235        use rand_chacha::ChaCha8Rng;
236        use rand_distr::{Distribution, StandardNormal};
237
238        let mut rng = ChaCha8Rng::seed_from_u64(seed);
239        Array2::from_shape_fn((rows, cols), |_| {
240            let v: f64 = StandardNormal.sample(&mut rng);
241            T::from(v).unwrap_or_else(T::zero)
242        })
243    }
244}
245
246// ============================================================================
247// VectorBuilder — 1D Array Construction
248// ============================================================================
249
/// Entry point for constructing one-dimensional arrays (`Array1<T>`).
///
/// Alongside the basic `zeros`, `ones`, `full`, `from_vec`, and `from_fn`
/// constructors, it offers NumPy-style range helpers (`linspace`, `arange`,
/// `logspace`) and seeded random constructors. The type itself is only a
/// `PhantomData<T>` marker; all functionality is in associated functions.
///
/// # Examples
///
/// ```rust
/// use scirs2_core::builders::VectorBuilder;
///
/// let v = VectorBuilder::<f64>::linspace(0.0, 1.0, 5);
/// assert!((v[0] - 0.0).abs() < 1e-12);
/// assert!((v[4] - 1.0).abs() < 1e-12);
///
/// let r = VectorBuilder::<f64>::arange(0.0, 5.0, 1.0);
/// assert_eq!(r.len(), 5);
/// ```
pub struct VectorBuilder<T>(::core::marker::PhantomData<T>);
268
269impl<T> VectorBuilder<T>
270where
271    T: Clone + Zero,
272{
273    /// Create a vector of all zeros with `n` elements.
274    ///
275    /// ```rust
276    /// use scirs2_core::builders::VectorBuilder;
277    ///
278    /// let v = VectorBuilder::<f64>::zeros(5);
279    /// assert_eq!(v.len(), 5);
280    /// assert_eq!(v[3], 0.0);
281    /// ```
282    pub fn zeros(n: usize) -> Array1<T> {
283        Array1::<T>::zeros(n)
284    }
285
286    /// Build a vector from a `Vec`.
287    ///
288    /// ```rust
289    /// use scirs2_core::builders::VectorBuilder;
290    ///
291    /// let v = VectorBuilder::from_vec(vec![1.0_f64, 2.0, 3.0]);
292    /// assert_eq!(v[1], 2.0);
293    /// ```
294    pub fn from_vec(data: Vec<T>) -> Array1<T> {
295        Array1::from(data)
296    }
297}
298
299impl<T> VectorBuilder<T>
300where
301    T: Clone + Zero + One,
302{
303    /// Create a vector of all ones with `n` elements.
304    ///
305    /// ```rust
306    /// use scirs2_core::builders::VectorBuilder;
307    ///
308    /// let v = VectorBuilder::<f64>::ones(4);
309    /// assert_eq!(v[2], 1.0);
310    /// ```
311    pub fn ones(n: usize) -> Array1<T> {
312        Array1::from_elem(n, T::one())
313    }
314}
315
316impl<T> VectorBuilder<T>
317where
318    T: Clone,
319{
320    /// Create a vector where element `i` is produced by `f(i)`.
321    ///
322    /// ```rust
323    /// use scirs2_core::builders::VectorBuilder;
324    ///
325    /// let squares = VectorBuilder::from_fn(5, |i| (i * i) as f64);
326    /// assert_eq!(squares[3], 9.0);
327    /// ```
328    pub fn from_fn<F>(n: usize, mut f: F) -> Array1<T>
329    where
330        F: FnMut(usize) -> T,
331    {
332        Array1::from_shape_fn(n, |i| f(i))
333    }
334
335    /// Create a vector filled with a constant value.
336    ///
337    /// ```rust
338    /// use scirs2_core::builders::VectorBuilder;
339    ///
340    /// let v = VectorBuilder::full(3, 7_i32);
341    /// assert_eq!(v[0], 7);
342    /// ```
343    pub fn full(n: usize, value: T) -> Array1<T> {
344        Array1::from_elem(n, value)
345    }
346}
347
348impl<T> VectorBuilder<T>
349where
350    T: Float + Display + Clone + MulAssign,
351{
352    /// Create `n` evenly spaced values from `start` to `stop` (inclusive).
353    ///
354    /// This is the analogue of NumPy's `np.linspace`.
355    ///
356    /// ```rust
357    /// use scirs2_core::builders::VectorBuilder;
358    ///
359    /// let v = VectorBuilder::<f64>::linspace(0.0, 4.0, 5);
360    /// assert!((v[0] - 0.0).abs() < 1e-12);
361    /// assert!((v[2] - 2.0).abs() < 1e-12);
362    /// assert!((v[4] - 4.0).abs() < 1e-12);
363    /// ```
364    pub fn linspace(start: T, stop: T, n: usize) -> Array1<T> {
365        if n == 0 {
366            return Array1::from(vec![]);
367        }
368        if n == 1 {
369            return Array1::from(vec![start]);
370        }
371        let steps = T::from(n - 1).unwrap_or_else(T::one);
372        Array1::from_shape_fn(n, |i| {
373            let t = T::from(i).unwrap_or_else(T::zero);
374            start + (stop - start) * (t / steps)
375        })
376    }
377
378    /// Create values from `start` up to (but not including) `stop` with step `step`.
379    ///
380    /// This is the analogue of NumPy's `np.arange`.
381    ///
382    /// ```rust
383    /// use scirs2_core::builders::VectorBuilder;
384    ///
385    /// let v = VectorBuilder::<f64>::arange(0.0, 5.0, 1.0);
386    /// assert_eq!(v.len(), 5);
387    /// assert!((v[0] - 0.0).abs() < 1e-12);
388    /// assert!((v[4] - 4.0).abs() < 1e-12);
389    ///
390    /// // Fractional step
391    /// let v2 = VectorBuilder::<f64>::arange(0.0, 1.0, 0.5);
392    /// assert_eq!(v2.len(), 2);
393    /// ```
394    pub fn arange(start: T, stop: T, step: T) -> Array1<T> {
395        if step == T::zero() || (stop - start).signum() != step.signum() {
396            return Array1::from(vec![]);
397        }
398        let n_float = ((stop - start) / step).ceil();
399        let n = n_float.to_usize().unwrap_or(0).max(0);
400        Array1::from_shape_fn(n, |i| start + step * T::from(i).unwrap_or_else(T::zero))
401    }
402
403    /// Create `n` values evenly spaced on a logarithmic scale.
404    ///
405    /// The values span from `10^start` to `10^stop` (inclusive), analogous to
406    /// NumPy's `np.logspace(start, stop, n, base=10)`.
407    ///
408    /// ```rust
409    /// use scirs2_core::builders::VectorBuilder;
410    ///
411    /// // 4 values from 10^0 = 1 to 10^3 = 1000
412    /// let v = VectorBuilder::<f64>::logspace(0.0, 3.0, 4);
413    /// assert!((v[0] - 1.0).abs() < 1e-10);
414    /// assert!((v[3] - 1000.0).abs() < 1e-8);
415    /// ```
416    pub fn logspace(start: T, stop: T, n: usize) -> Array1<T> {
417        let lin = Self::linspace(start, stop, n);
418        lin.mapv(|x| T::from(10.0_f64).unwrap_or_else(T::one).powf(x))
419    }
420
421    /// Create `n` uniform random values in `[0, 1)` using a seeded ChaCha8 RNG.
422    ///
423    /// ```rust
424    /// use scirs2_core::builders::VectorBuilder;
425    ///
426    /// let v = VectorBuilder::<f64>::rand(5, 42);
427    /// assert_eq!(v.len(), 5);
428    /// assert!(v.iter().all(|&x| x >= 0.0 && x < 1.0));
429    /// ```
430    pub fn rand(n: usize, seed: u64) -> Array1<T> {
431        use rand::SeedableRng;
432        use rand_chacha::ChaCha8Rng;
433
434        let mut rng = ChaCha8Rng::seed_from_u64(seed);
435        Array1::from_shape_fn(n, |_| {
436            use rand::Rng;
437            let v: f64 = rng.random();
438            T::from(v).unwrap_or_else(T::zero)
439        })
440    }
441
442    /// Create `n` standard-normal random values using a seeded ChaCha8 RNG.
443    ///
444    /// ```rust
445    /// use scirs2_core::builders::VectorBuilder;
446    ///
447    /// let v = VectorBuilder::<f64>::randn(5, 0);
448    /// assert_eq!(v.len(), 5);
449    /// ```
450    pub fn randn(n: usize, seed: u64) -> Array1<T> {
451        use rand::SeedableRng;
452        use rand_chacha::ChaCha8Rng;
453        use rand_distr::{Distribution, StandardNormal};
454
455        let mut rng = ChaCha8Rng::seed_from_u64(seed);
456        Array1::from_shape_fn(n, |_| {
457            let v: f64 = StandardNormal.sample(&mut rng);
458            T::from(v).unwrap_or_else(T::zero)
459        })
460    }
461}
462
463// ============================================================================
464// ArrayBuilder — Generic N-dimensional Array Construction
465// ============================================================================
466
/// Entry point for constructing arrays of arbitrary dimensionality.
///
/// `MatrixBuilder` is fixed to two dimensions and `VectorBuilder` to one;
/// `ArrayBuilder` is generic over any [`ndarray::Dimension`] `D`, including the
/// dynamic `IxDyn` for shapes that are only known at runtime. Like the other
/// builders it is a `PhantomData` marker whose functionality lives in
/// associated functions.
///
/// # Examples
///
/// ```rust
/// use scirs2_core::builders::ArrayBuilder;
///
/// let a2 = ArrayBuilder::<f64, _>::zeros(ndarray::Ix2(3, 4));
/// assert_eq!(a2.shape(), &[3, 4]);
///
/// let a3 = ArrayBuilder::<f64, _>::zeros(ndarray::Ix3(2, 3, 4));
/// assert_eq!(a3.shape(), &[2, 3, 4]);
///
/// // Dynamic dimension
/// let ad = ArrayBuilder::<f64, ndarray::IxDyn>::zeros_dyn(&[2, 3, 4]);
/// assert_eq!(ad.shape(), &[2, 3, 4]);
/// ```
pub struct ArrayBuilder<T, D>(::core::marker::PhantomData<(T, D)>);
489
490impl<T, D> ArrayBuilder<T, D>
491where
492    T: Clone + Zero,
493    D: Dimension,
494{
495    /// Create a zeros array with the given shape.
496    ///
497    /// ```rust
498    /// use scirs2_core::builders::ArrayBuilder;
499    ///
500    /// let a = ArrayBuilder::<f64, _>::zeros(ndarray::Ix2(3, 4));
501    /// assert_eq!(a.shape(), &[3, 4]);
502    /// ```
503    pub fn zeros<Sh>(shape: Sh) -> ::ndarray::Array<T, D>
504    where
505        Sh: IntoDimension<Dim = D>,
506    {
507        ::ndarray::Array::zeros(shape)
508    }
509
510    /// Create an array filled with a constant value.
511    ///
512    /// ```rust
513    /// use scirs2_core::builders::ArrayBuilder;
514    ///
515    /// let a = ArrayBuilder::<i32, _>::full(ndarray::Ix2(2, 3), 7);
516    /// assert_eq!(a[[0, 0]], 7);
517    /// ```
518    pub fn full<Sh>(shape: Sh, value: T) -> ::ndarray::Array<T, D>
519    where
520        Sh: IntoDimension<Dim = D>,
521    {
522        ::ndarray::Array::from_elem(shape, value)
523    }
524
525    /// Create an array where each element is produced by a closure receiving the
526    /// dimension pattern (e.g. `(row, col)` for 2D, `(i, j, k)` for 3D, etc.).
527    ///
528    /// ```rust
529    /// use scirs2_core::builders::ArrayBuilder;
530    ///
531    /// // 3×3 matrix: element = row + col
532    /// let a = ArrayBuilder::<usize, ndarray::Ix2>::from_fn(
533    ///     ndarray::Ix2(3, 3),
534    ///     |(r, c)| r + c,
535    /// );
536    /// assert_eq!(a[[2, 2]], 4);
537    /// ```
538    pub fn from_fn<Sh, F>(shape: Sh, f: F) -> ::ndarray::Array<T, D>
539    where
540        Sh: IntoDimension<Dim = D>,
541        F: FnMut(D::Pattern) -> T,
542    {
543        ::ndarray::Array::from_shape_fn(shape, f)
544    }
545
546    /// Build an array from a flat `Vec` of elements in C-order (row-major).
547    ///
548    /// Returns a `CoreError` if the element count does not match the given shape.
549    ///
550    /// ```rust
551    /// use scirs2_core::builders::ArrayBuilder;
552    ///
553    /// let a = ArrayBuilder::<f64, ndarray::Ix2>::from_vec(
554    ///     vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
555    ///     ndarray::Ix2(2, 3),
556    /// ).expect("element count matches");
557    /// assert_eq!(a[[1, 2]], 6.0);
558    /// ```
559    pub fn from_vec<Sh>(data: Vec<T>, shape: Sh) -> CoreResult<::ndarray::Array<T, D>>
560    where
561        Sh: IntoDimension<Dim = D>,
562    {
563        ::ndarray::Array::from_shape_vec(shape, data).map_err(|e: ShapeError| {
564            CoreError::InvalidInput(ErrorContext::new(format!(
565                "ArrayBuilder::from_vec shape error: {e}"
566            )))
567        })
568    }
569}
570
571impl<T> ArrayBuilder<T, IxDyn>
572where
573    T: Clone + Zero,
574{
575    /// Create a dynamic-dimensional zeros array from a runtime shape slice.
576    ///
577    /// ```rust
578    /// use scirs2_core::builders::ArrayBuilder;
579    ///
580    /// let a = ArrayBuilder::<f64, ndarray::IxDyn>::zeros_dyn(&[2, 3, 4]);
581    /// assert_eq!(a.ndim(), 3);
582    /// assert_eq!(a.shape(), &[2, 3, 4]);
583    /// ```
584    pub fn zeros_dyn(shape: &[usize]) -> ArrayD<T> {
585        ArrayD::zeros(IxDyn(shape))
586    }
587
588    /// Create a dynamic-dimensional array filled with `value`.
589    pub fn full_dyn(shape: &[usize], value: T) -> ArrayD<T> {
590        ArrayD::from_elem(IxDyn(shape), value)
591    }
592}
593
594// ============================================================================
595// Tests
596// ============================================================================
597
/// Unit tests for the three builders. Each test exercises one constructor
/// through the public API only.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    // --- MatrixBuilder tests ---

    #[test]
    fn test_matrix_zeros() {
        let m = MatrixBuilder::<f64>::zeros(3, 4);
        assert_eq!(m.shape(), &[3, 4]);
        assert!(m.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_matrix_ones() {
        let m = MatrixBuilder::<f64>::ones(2, 5);
        assert_eq!(m.shape(), &[2, 5]);
        assert!(m.iter().all(|&v| v == 1.0));
    }

    #[test]
    fn test_matrix_eye() {
        let eye = MatrixBuilder::<f64>::eye(3);
        assert_eq!(eye.shape(), &[3, 3]);
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_abs_diff_eq!(eye[[i, j]], expected);
            }
        }
    }

    #[test]
    fn test_matrix_eye_empty() {
        // A 0×0 identity is a valid, empty matrix.
        let eye = MatrixBuilder::<f64>::eye(0);
        assert_eq!(eye.shape(), &[0, 0]);
        assert_eq!(eye.len(), 0);
    }

    #[test]
    fn test_matrix_from_vec() {
        let m = MatrixBuilder::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0], 2, 2)
            .expect("element count should match");
        assert_eq!(m[[0, 0]], 1.0);
        assert_eq!(m[[0, 1]], 2.0);
        assert_eq!(m[[1, 0]], 3.0);
        assert_eq!(m[[1, 1]], 4.0);
    }

    #[test]
    fn test_matrix_from_vec_error() {
        // Wrong element count → error
        let result = MatrixBuilder::<f64>::from_vec(vec![1.0, 2.0, 3.0], 2, 2);
        assert!(result.is_err());
    }

    #[test]
    fn test_matrix_from_fn() {
        let m = MatrixBuilder::from_fn(3, 3, |r, c| (r * 3 + c) as f64);
        for r in 0..3 {
            for c in 0..3 {
                assert_abs_diff_eq!(m[[r, c]], (r * 3 + c) as f64);
            }
        }
    }

    #[test]
    fn test_matrix_full() {
        let m = MatrixBuilder::full(3, 3, 42_i32);
        assert!(m.iter().all(|&v| v == 42));
    }

    #[test]
    fn test_matrix_rand() {
        let m = MatrixBuilder::<f64>::rand(10, 10, 99);
        assert_eq!(m.shape(), &[10, 10]);
        assert!(m.iter().all(|v| (0.0..1.0).contains(v)));
        // Deterministic: same seed → same values
        let m2 = MatrixBuilder::<f64>::rand(10, 10, 99);
        assert_eq!(m, m2);
    }

    #[test]
    fn test_matrix_randn() {
        let m = MatrixBuilder::<f64>::randn(100, 100, 0);
        // Mean should be roughly 0, std roughly 1
        let mean = m.mean().expect("non-empty");
        assert!(mean.abs() < 0.5, "mean={mean}");
    }

    // --- VectorBuilder tests ---

    #[test]
    fn test_vector_zeros() {
        let v = VectorBuilder::<f64>::zeros(5);
        assert_eq!(v.len(), 5);
        assert!(v.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_vector_ones() {
        let v = VectorBuilder::<f64>::ones(4);
        assert_eq!(v.len(), 4);
        assert!(v.iter().all(|&x| x == 1.0));
    }

    #[test]
    fn test_vector_from_vec() {
        let v = VectorBuilder::from_vec(vec![10.0_f64, 20.0, 30.0]);
        assert_eq!(v.len(), 3);
        assert_eq!(v[1], 20.0);
    }

    #[test]
    fn test_vector_from_fn() {
        let v = VectorBuilder::from_fn(5, |i| i as f64 * 2.0);
        assert_abs_diff_eq!(v[3], 6.0);
    }

    #[test]
    fn test_vector_full() {
        // An arbitrary fill value; deliberately not an approximation of a
        // well-known constant, which clippy's `approx_constant` lint rejects.
        let v = VectorBuilder::full(4, 2.5_f64);
        assert!(v.iter().all(|&x| (x - 2.5).abs() < 1e-12));
    }

    #[test]
    fn test_vector_linspace() {
        let v = VectorBuilder::<f64>::linspace(0.0, 4.0, 5);
        assert_eq!(v.len(), 5);
        for (i, &val) in v.iter().enumerate() {
            assert_abs_diff_eq!(val, i as f64, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_vector_linspace_single() {
        let v = VectorBuilder::<f64>::linspace(3.0, 3.0, 1);
        assert_eq!(v.len(), 1);
        assert_abs_diff_eq!(v[0], 3.0);
    }

    #[test]
    fn test_vector_linspace_empty() {
        let v = VectorBuilder::<f64>::linspace(0.0, 1.0, 0);
        assert_eq!(v.len(), 0);
    }

    #[test]
    fn test_vector_arange() {
        let v = VectorBuilder::<f64>::arange(0.0, 5.0, 1.0);
        assert_eq!(v.len(), 5);
        for (i, &val) in v.iter().enumerate() {
            assert_abs_diff_eq!(val, i as f64, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_vector_arange_fractional() {
        let v = VectorBuilder::<f64>::arange(0.0, 1.0, 0.5);
        assert_eq!(v.len(), 2);
        assert_abs_diff_eq!(v[0], 0.0, epsilon = 1e-12);
        assert_abs_diff_eq!(v[1], 0.5, epsilon = 1e-12);
    }

    #[test]
    fn test_vector_arange_negative_step() {
        // Descending range: 5, 4, 3, 2, 1 (stop is exclusive).
        let v = VectorBuilder::<f64>::arange(5.0, 0.0, -1.0);
        assert_eq!(v.len(), 5);
        assert_abs_diff_eq!(v[0], 5.0, epsilon = 1e-12);
        assert_abs_diff_eq!(v[4], 1.0, epsilon = 1e-12);
    }

    #[test]
    fn test_vector_arange_empty() {
        // step 0 → empty
        let v = VectorBuilder::<f64>::arange(0.0, 5.0, 0.0);
        assert_eq!(v.len(), 0);
        // wrong direction → empty
        let v2 = VectorBuilder::<f64>::arange(5.0, 0.0, 1.0);
        assert_eq!(v2.len(), 0);
    }

    #[test]
    fn test_vector_logspace() {
        let v = VectorBuilder::<f64>::logspace(0.0, 3.0, 4);
        assert_eq!(v.len(), 4);
        assert_abs_diff_eq!(v[0], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(v[1], 10.0, epsilon = 1e-8);
        assert_abs_diff_eq!(v[2], 100.0, epsilon = 1e-6);
        assert_abs_diff_eq!(v[3], 1000.0, epsilon = 1e-4);
    }

    #[test]
    fn test_vector_rand() {
        let v = VectorBuilder::<f64>::rand(20, 7);
        assert_eq!(v.len(), 20);
        assert!(v.iter().all(|x| (0.0..1.0).contains(x)));
        // Determinism
        let v2 = VectorBuilder::<f64>::rand(20, 7);
        assert_eq!(v, v2);
    }

    #[test]
    fn test_vector_randn() {
        let v = VectorBuilder::<f64>::randn(1000, 123);
        assert_eq!(v.len(), 1000);
        let mean = v.mean().expect("non-empty");
        assert!(mean.abs() < 0.2, "mean={mean}");
    }

    // --- ArrayBuilder tests ---

    #[test]
    fn test_array_builder_zeros_2d() {
        let a = ArrayBuilder::<f64, ::ndarray::Ix2>::zeros(::ndarray::Ix2(3, 4));
        assert_eq!(a.shape(), &[3, 4]);
        assert!(a.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_array_builder_zeros_3d() {
        let a = ArrayBuilder::<f64, ::ndarray::Ix3>::zeros(::ndarray::Ix3(2, 3, 4));
        assert_eq!(a.shape(), &[2, 3, 4]);
    }

    #[test]
    fn test_array_builder_zeros_dyn() {
        let a = ArrayBuilder::<f64, ::ndarray::IxDyn>::zeros_dyn(&[2, 3, 4]);
        assert_eq!(a.ndim(), 3);
        assert_eq!(a.shape(), &[2, 3, 4]);
    }

    #[test]
    fn test_array_builder_full() {
        let a = ArrayBuilder::<i32, ::ndarray::Ix2>::full(::ndarray::Ix2(3, 3), 7);
        assert!(a.iter().all(|&v| v == 7));
    }

    #[test]
    fn test_array_builder_full_dyn() {
        let a = ArrayBuilder::<i32, ::ndarray::IxDyn>::full_dyn(&[2, 3], 9);
        assert_eq!(a.ndim(), 2);
        assert_eq!(a.shape(), &[2, 3]);
        assert!(a.iter().all(|&v| v == 9));
    }

    #[test]
    fn test_array_builder_from_fn() {
        let a = ArrayBuilder::<usize, ::ndarray::Ix2>::from_fn(::ndarray::Ix2(3, 3), |(r, c)| r + c);
        assert_eq!(a[[0, 0]], 0);
        assert_eq!(a[[1, 2]], 3);
        assert_eq!(a[[2, 2]], 4);
    }

    #[test]
    fn test_array_builder_from_vec_ok() {
        let a = ArrayBuilder::<f64, ::ndarray::Ix2>::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            ::ndarray::Ix2(2, 3),
        )
        .expect("valid shape");
        assert_eq!(a[[1, 2]], 6.0);
    }

    #[test]
    fn test_array_builder_from_vec_err() {
        let result = ArrayBuilder::<f64, ::ndarray::Ix2>::from_vec(
            vec![1.0, 2.0, 3.0],
            ::ndarray::Ix2(2, 3), // needs 6 elements
        );
        assert!(result.is_err());
    }
}