scirs2_core/builders.rs
1//! # Ergonomic Builder Patterns for Array Construction
2//!
3//! This module provides fluent builder patterns that make common array-construction
4//! tasks more discoverable and IDE-friendly. Instead of remembering multiple
5//! constructors scattered across ndarray, users can use a single entry-point and
6//! let IDE autocomplete guide them.
7//!
8//! ## Design Goals
9//!
10//! - **Discoverability**: All construction paths live under `MatrixBuilder`,
11//! `VectorBuilder`, and `ArrayBuilder` — easy to find in IDEs.
12//! - **No unwrap**: Every fallible operation returns `CoreResult`.
13//! - **Generic**: Works for any numeric type satisfying the appropriate traits.
14//! - **Zero-cost**: The builders are thin wrappers; all cost is in the actual
15//! array allocation, matching what you would write by hand.
16//!
17//! ## Usage
18//!
19//! ```rust
20//! use scirs2_core::builders::{MatrixBuilder, VectorBuilder, ArrayBuilder};
21//!
22//! // 2D Matrix construction
23//! let eye3 = MatrixBuilder::<f64>::eye(3);
24//! let zeros = MatrixBuilder::<f64>::zeros(4, 4);
25//! let from_data = MatrixBuilder::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2)
26//! .expect("correct element count");
27//! let from_fn = MatrixBuilder::from_fn(3, 3, |r, c| if r == c { 1.0f64 } else { 0.0 });
28//!
29//! // 1D Vector construction
30//! let linspace = VectorBuilder::<f64>::linspace(0.0, 1.0, 11);
31//! let arange = VectorBuilder::<f64>::arange(0.0, 5.0, 1.0);
32//! let logspace = VectorBuilder::<f64>::logspace(0.0, 3.0, 4);
33//! let from_vec = VectorBuilder::from_vec(vec![1.0, 2.0, 3.0]);
34//!
35//! // Generic multi-dim array
36//! let shaped = ArrayBuilder::<f64, _>::zeros(ndarray::Ix2(3, 4));
37//! ```
38
39use crate::error::{CoreError, CoreResult, ErrorContext};
40use ::ndarray::{Array1, Array2, ArrayD, Dimension, IntoDimension, IxDyn, ShapeError};
41use num_traits::{Float, One, Zero};
42use std::fmt::Display;
43use std::ops::MulAssign;
44
45// ============================================================================
46// MatrixBuilder — 2D Matrix Construction
47// ============================================================================
48
/// Entry point for constructing two-dimensional [`Array2`] matrices.
///
/// `MatrixBuilder` is never instantiated: every constructor is an associated
/// function, so typing `MatrixBuilder::` in an IDE lists every available way
/// to build a matrix.
///
/// # Type Parameter
///
/// `T` is the element type. Typical numeric choices are `f64`, `f32`, `i32`,
/// `i64`, and `u64`; each `impl` block states the traits it additionally needs.
///
/// # Examples
///
/// ```rust
/// use scirs2_core::builders::MatrixBuilder;
///
/// // Identity matrix: ones on the diagonal, zeros elsewhere
/// let identity = MatrixBuilder::<f64>::eye(3);
/// assert_eq!(identity[[0, 0]], 1.0);
/// assert_eq!(identity[[0, 1]], 0.0);
///
/// // Constant-filled matrices
/// let zeros = MatrixBuilder::<f64>::zeros(2, 3);
/// let ones = MatrixBuilder::<f64>::ones(2, 3);
/// assert_eq!(zeros.shape(), ones.shape());
///
/// // Element-wise construction from a closure of (row, col)
/// let indexed = MatrixBuilder::from_fn(3, 3, |r, c| (r * 3 + c) as f64);
/// assert_eq!(indexed[[1, 2]], 5.0);
/// ```
pub struct MatrixBuilder<T>(core::marker::PhantomData<T>);
77
78impl<T> MatrixBuilder<T>
79where
80 T: Clone + Zero,
81{
82 /// Create a matrix of all zeros with shape `(rows, cols)`.
83 ///
84 /// ```rust
85 /// use scirs2_core::builders::MatrixBuilder;
86 ///
87 /// let m = MatrixBuilder::<f64>::zeros(3, 4);
88 /// assert_eq!(m.shape(), &[3, 4]);
89 /// assert_eq!(m[[0, 0]], 0.0);
90 /// ```
91 pub fn zeros(rows: usize, cols: usize) -> Array2<T> {
92 Array2::<T>::zeros((rows, cols))
93 }
94
95 /// Build a matrix from a flat `Vec` of elements in row-major order.
96 ///
97 /// Returns an error if the number of elements does not match `rows * cols`.
98 ///
99 /// ```rust
100 /// use scirs2_core::builders::MatrixBuilder;
101 ///
102 /// let m = MatrixBuilder::from_vec(vec![1.0f64, 2.0, 3.0, 4.0], 2, 2)
103 /// .expect("element count matches");
104 /// assert_eq!(m[[0, 0]], 1.0);
105 /// assert_eq!(m[[1, 1]], 4.0);
106 /// ```
107 pub fn from_vec(data: Vec<T>, rows: usize, cols: usize) -> CoreResult<Array2<T>> {
108 if data.len() != rows * cols {
109 return Err(CoreError::InvalidInput(ErrorContext::new(format!(
110 "MatrixBuilder::from_vec: expected {} elements for a {}×{} matrix, got {}",
111 rows * cols,
112 rows,
113 cols,
114 data.len()
115 ))));
116 }
117 Array2::from_shape_vec((rows, cols), data).map_err(|e: ShapeError| {
118 CoreError::InvalidInput(ErrorContext::new(format!(
119 "MatrixBuilder::from_vec shape error: {e}"
120 )))
121 })
122 }
123}
124
125impl<T> MatrixBuilder<T>
126where
127 T: Clone + Zero + One,
128{
129 /// Create a square identity matrix of size `n × n`.
130 ///
131 /// ```rust
132 /// use scirs2_core::builders::MatrixBuilder;
133 ///
134 /// let eye = MatrixBuilder::<f64>::eye(3);
135 /// assert_eq!(eye[[2, 2]], 1.0);
136 /// assert_eq!(eye[[0, 1]], 0.0);
137 /// ```
138 pub fn eye(n: usize) -> Array2<T> {
139 let mut m = Array2::<T>::zeros((n, n));
140 for i in 0..n {
141 m[[i, i]] = T::one();
142 }
143 m
144 }
145
146 /// Create a matrix of all ones with shape `(rows, cols)`.
147 ///
148 /// ```rust
149 /// use scirs2_core::builders::MatrixBuilder;
150 ///
151 /// let m = MatrixBuilder::<f64>::ones(2, 3);
152 /// assert_eq!(m[[1, 2]], 1.0);
153 /// ```
154 pub fn ones(rows: usize, cols: usize) -> Array2<T> {
155 Array2::<T>::from_elem((rows, cols), T::one())
156 }
157}
158
159impl<T> MatrixBuilder<T>
160where
161 T: Clone,
162{
163 /// Create a matrix filled with a single constant value.
164 ///
165 /// ```rust
166 /// use scirs2_core::builders::MatrixBuilder;
167 ///
168 /// let m = MatrixBuilder::full(3, 3, 7_i32);
169 /// assert_eq!(m[[0, 0]], 7);
170 /// assert_eq!(m[[2, 2]], 7);
171 /// ```
172 pub fn full(rows: usize, cols: usize, value: T) -> Array2<T> {
173 Array2::from_elem((rows, cols), value)
174 }
175
176 /// Create a matrix where each element is produced by calling `f(row, col)`.
177 ///
178 /// ```rust
179 /// use scirs2_core::builders::MatrixBuilder;
180 ///
181 /// let m = MatrixBuilder::from_fn(3, 3, |r, c| (r * 3 + c) as f64);
182 /// assert_eq!(m[[0, 0]], 0.0);
183 /// assert_eq!(m[[2, 2]], 8.0);
184 /// ```
185 pub fn from_fn<F>(rows: usize, cols: usize, mut f: F) -> Array2<T>
186 where
187 F: FnMut(usize, usize) -> T,
188 {
189 Array2::from_shape_fn((rows, cols), |(r, c)| f(r, c))
190 }
191}
192
193impl<T> MatrixBuilder<T>
194where
195 T: Float + Clone,
196{
197 /// Create a matrix populated with uniform random values in `[0, 1)` using a seeded
198 /// ChaCha8 RNG for reproducibility.
199 ///
200 /// The `seed` parameter lets callers produce deterministic results in tests
201 /// and benchmarks while still getting varied values in production by passing
202 /// different seeds.
203 ///
204 /// ```rust
205 /// use scirs2_core::builders::MatrixBuilder;
206 ///
207 /// let m = MatrixBuilder::<f64>::rand(3, 3, 42);
208 /// assert_eq!(m.shape(), &[3, 3]);
209 /// // All values should be in [0, 1)
210 /// assert!(m.iter().all(|&v| v >= 0.0 && v < 1.0));
211 /// ```
212 pub fn rand(rows: usize, cols: usize, seed: u64) -> Array2<T> {
213 use rand::SeedableRng;
214 use rand_chacha::ChaCha8Rng;
215
216 let mut rng = ChaCha8Rng::seed_from_u64(seed);
217 Array2::from_shape_fn((rows, cols), |_| {
218 // Generate a uniform f64 in [0, 1) and cast to T
219 use rand::Rng;
220 let v: f64 = rng.random();
221 T::from(v).unwrap_or_else(T::zero)
222 })
223 }
224
225 /// Create a matrix populated with standard normal (`N(0, 1)`) random values.
226 ///
227 /// ```rust
228 /// use scirs2_core::builders::MatrixBuilder;
229 ///
230 /// let m = MatrixBuilder::<f64>::randn(4, 4, 0);
231 /// assert_eq!(m.shape(), &[4, 4]);
232 /// ```
233 pub fn randn(rows: usize, cols: usize, seed: u64) -> Array2<T> {
234 use rand::SeedableRng;
235 use rand_chacha::ChaCha8Rng;
236 use rand_distr::{Distribution, StandardNormal};
237
238 let mut rng = ChaCha8Rng::seed_from_u64(seed);
239 Array2::from_shape_fn((rows, cols), |_| {
240 let v: f64 = StandardNormal.sample(&mut rng);
241 T::from(v).unwrap_or_else(T::zero)
242 })
243 }
244}
245
246// ============================================================================
247// VectorBuilder — 1D Array Construction
248// ============================================================================
249
/// Entry point for constructing one-dimensional [`Array1`] vectors.
///
/// In addition to the basic `zeros`, `ones`, `full`, `from_vec`, and `from_fn`
/// constructors, it offers NumPy-style ranges (`linspace`, `arange`,
/// `logspace`) and seeded random vectors. Every constructor is an associated
/// function; the type itself is never instantiated.
///
/// # Examples
///
/// ```rust
/// use scirs2_core::builders::VectorBuilder;
///
/// let grid = VectorBuilder::<f64>::linspace(0.0, 1.0, 5);
/// assert!((grid[0] - 0.0).abs() < 1e-12);
/// assert!((grid[4] - 1.0).abs() < 1e-12);
///
/// let steps = VectorBuilder::<f64>::arange(0.0, 5.0, 1.0);
/// assert_eq!(steps.len(), 5);
/// ```
pub struct VectorBuilder<T>(core::marker::PhantomData<T>);
268
269impl<T> VectorBuilder<T>
270where
271 T: Clone + Zero,
272{
273 /// Create a vector of all zeros with `n` elements.
274 ///
275 /// ```rust
276 /// use scirs2_core::builders::VectorBuilder;
277 ///
278 /// let v = VectorBuilder::<f64>::zeros(5);
279 /// assert_eq!(v.len(), 5);
280 /// assert_eq!(v[3], 0.0);
281 /// ```
282 pub fn zeros(n: usize) -> Array1<T> {
283 Array1::<T>::zeros(n)
284 }
285
286 /// Build a vector from a `Vec`.
287 ///
288 /// ```rust
289 /// use scirs2_core::builders::VectorBuilder;
290 ///
291 /// let v = VectorBuilder::from_vec(vec![1.0_f64, 2.0, 3.0]);
292 /// assert_eq!(v[1], 2.0);
293 /// ```
294 pub fn from_vec(data: Vec<T>) -> Array1<T> {
295 Array1::from(data)
296 }
297}
298
299impl<T> VectorBuilder<T>
300where
301 T: Clone + Zero + One,
302{
303 /// Create a vector of all ones with `n` elements.
304 ///
305 /// ```rust
306 /// use scirs2_core::builders::VectorBuilder;
307 ///
308 /// let v = VectorBuilder::<f64>::ones(4);
309 /// assert_eq!(v[2], 1.0);
310 /// ```
311 pub fn ones(n: usize) -> Array1<T> {
312 Array1::from_elem(n, T::one())
313 }
314}
315
316impl<T> VectorBuilder<T>
317where
318 T: Clone,
319{
320 /// Create a vector where element `i` is produced by `f(i)`.
321 ///
322 /// ```rust
323 /// use scirs2_core::builders::VectorBuilder;
324 ///
325 /// let squares = VectorBuilder::from_fn(5, |i| (i * i) as f64);
326 /// assert_eq!(squares[3], 9.0);
327 /// ```
328 pub fn from_fn<F>(n: usize, mut f: F) -> Array1<T>
329 where
330 F: FnMut(usize) -> T,
331 {
332 Array1::from_shape_fn(n, |i| f(i))
333 }
334
335 /// Create a vector filled with a constant value.
336 ///
337 /// ```rust
338 /// use scirs2_core::builders::VectorBuilder;
339 ///
340 /// let v = VectorBuilder::full(3, 7_i32);
341 /// assert_eq!(v[0], 7);
342 /// ```
343 pub fn full(n: usize, value: T) -> Array1<T> {
344 Array1::from_elem(n, value)
345 }
346}
347
348impl<T> VectorBuilder<T>
349where
350 T: Float + Display + Clone + MulAssign,
351{
352 /// Create `n` evenly spaced values from `start` to `stop` (inclusive).
353 ///
354 /// This is the analogue of NumPy's `np.linspace`.
355 ///
356 /// ```rust
357 /// use scirs2_core::builders::VectorBuilder;
358 ///
359 /// let v = VectorBuilder::<f64>::linspace(0.0, 4.0, 5);
360 /// assert!((v[0] - 0.0).abs() < 1e-12);
361 /// assert!((v[2] - 2.0).abs() < 1e-12);
362 /// assert!((v[4] - 4.0).abs() < 1e-12);
363 /// ```
364 pub fn linspace(start: T, stop: T, n: usize) -> Array1<T> {
365 if n == 0 {
366 return Array1::from(vec![]);
367 }
368 if n == 1 {
369 return Array1::from(vec![start]);
370 }
371 let steps = T::from(n - 1).unwrap_or_else(T::one);
372 Array1::from_shape_fn(n, |i| {
373 let t = T::from(i).unwrap_or_else(T::zero);
374 start + (stop - start) * (t / steps)
375 })
376 }
377
378 /// Create values from `start` up to (but not including) `stop` with step `step`.
379 ///
380 /// This is the analogue of NumPy's `np.arange`.
381 ///
382 /// ```rust
383 /// use scirs2_core::builders::VectorBuilder;
384 ///
385 /// let v = VectorBuilder::<f64>::arange(0.0, 5.0, 1.0);
386 /// assert_eq!(v.len(), 5);
387 /// assert!((v[0] - 0.0).abs() < 1e-12);
388 /// assert!((v[4] - 4.0).abs() < 1e-12);
389 ///
390 /// // Fractional step
391 /// let v2 = VectorBuilder::<f64>::arange(0.0, 1.0, 0.5);
392 /// assert_eq!(v2.len(), 2);
393 /// ```
394 pub fn arange(start: T, stop: T, step: T) -> Array1<T> {
395 if step == T::zero() || (stop - start).signum() != step.signum() {
396 return Array1::from(vec![]);
397 }
398 let n_float = ((stop - start) / step).ceil();
399 let n = n_float.to_usize().unwrap_or(0).max(0);
400 Array1::from_shape_fn(n, |i| start + step * T::from(i).unwrap_or_else(T::zero))
401 }
402
403 /// Create `n` values evenly spaced on a logarithmic scale.
404 ///
405 /// The values span from `10^start` to `10^stop` (inclusive), analogous to
406 /// NumPy's `np.logspace(start, stop, n, base=10)`.
407 ///
408 /// ```rust
409 /// use scirs2_core::builders::VectorBuilder;
410 ///
411 /// // 4 values from 10^0 = 1 to 10^3 = 1000
412 /// let v = VectorBuilder::<f64>::logspace(0.0, 3.0, 4);
413 /// assert!((v[0] - 1.0).abs() < 1e-10);
414 /// assert!((v[3] - 1000.0).abs() < 1e-8);
415 /// ```
416 pub fn logspace(start: T, stop: T, n: usize) -> Array1<T> {
417 let lin = Self::linspace(start, stop, n);
418 lin.mapv(|x| T::from(10.0_f64).unwrap_or_else(T::one).powf(x))
419 }
420
421 /// Create `n` uniform random values in `[0, 1)` using a seeded ChaCha8 RNG.
422 ///
423 /// ```rust
424 /// use scirs2_core::builders::VectorBuilder;
425 ///
426 /// let v = VectorBuilder::<f64>::rand(5, 42);
427 /// assert_eq!(v.len(), 5);
428 /// assert!(v.iter().all(|&x| x >= 0.0 && x < 1.0));
429 /// ```
430 pub fn rand(n: usize, seed: u64) -> Array1<T> {
431 use rand::SeedableRng;
432 use rand_chacha::ChaCha8Rng;
433
434 let mut rng = ChaCha8Rng::seed_from_u64(seed);
435 Array1::from_shape_fn(n, |_| {
436 use rand::Rng;
437 let v: f64 = rng.random();
438 T::from(v).unwrap_or_else(T::zero)
439 })
440 }
441
442 /// Create `n` standard-normal random values using a seeded ChaCha8 RNG.
443 ///
444 /// ```rust
445 /// use scirs2_core::builders::VectorBuilder;
446 ///
447 /// let v = VectorBuilder::<f64>::randn(5, 0);
448 /// assert_eq!(v.len(), 5);
449 /// ```
450 pub fn randn(n: usize, seed: u64) -> Array1<T> {
451 use rand::SeedableRng;
452 use rand_chacha::ChaCha8Rng;
453 use rand_distr::{Distribution, StandardNormal};
454
455 let mut rng = ChaCha8Rng::seed_from_u64(seed);
456 Array1::from_shape_fn(n, |_| {
457 let v: f64 = StandardNormal.sample(&mut rng);
458 T::from(v).unwrap_or_else(T::zero)
459 })
460 }
461}
462
463// ============================================================================
464// ArrayBuilder — Generic N-dimensional Array Construction
465// ============================================================================
466
/// Builder for N-dimensional arrays of any [`Dimension`].
///
/// Use `MatrixBuilder` for exactly two axes and `VectorBuilder` for exactly
/// one. `ArrayBuilder` covers every other rank, including shapes whose rank is
/// only known at runtime (via [`IxDyn`]).
///
/// # Examples
///
/// ```rust
/// use scirs2_core::builders::ArrayBuilder;
///
/// let grid = ArrayBuilder::<f64, _>::zeros(ndarray::Ix2(3, 4));
/// assert_eq!(grid.shape(), &[3, 4]);
///
/// let cube = ArrayBuilder::<f64, _>::zeros(ndarray::Ix3(2, 3, 4));
/// assert_eq!(cube.shape(), &[2, 3, 4]);
///
/// // Rank chosen at runtime
/// let dynamic = ArrayBuilder::<f64, ndarray::IxDyn>::zeros_dyn(&[2, 3, 4]);
/// assert_eq!(dynamic.shape(), &[2, 3, 4]);
/// ```
pub struct ArrayBuilder<T, D>(core::marker::PhantomData<(T, D)>);
489
490impl<T, D> ArrayBuilder<T, D>
491where
492 T: Clone + Zero,
493 D: Dimension,
494{
495 /// Create a zeros array with the given shape.
496 ///
497 /// ```rust
498 /// use scirs2_core::builders::ArrayBuilder;
499 ///
500 /// let a = ArrayBuilder::<f64, _>::zeros(ndarray::Ix2(3, 4));
501 /// assert_eq!(a.shape(), &[3, 4]);
502 /// ```
503 pub fn zeros<Sh>(shape: Sh) -> ::ndarray::Array<T, D>
504 where
505 Sh: IntoDimension<Dim = D>,
506 {
507 ::ndarray::Array::zeros(shape)
508 }
509
510 /// Create an array filled with a constant value.
511 ///
512 /// ```rust
513 /// use scirs2_core::builders::ArrayBuilder;
514 ///
515 /// let a = ArrayBuilder::<i32, _>::full(ndarray::Ix2(2, 3), 7);
516 /// assert_eq!(a[[0, 0]], 7);
517 /// ```
518 pub fn full<Sh>(shape: Sh, value: T) -> ::ndarray::Array<T, D>
519 where
520 Sh: IntoDimension<Dim = D>,
521 {
522 ::ndarray::Array::from_elem(shape, value)
523 }
524
525 /// Create an array where each element is produced by a closure receiving the
526 /// dimension pattern (e.g. `(row, col)` for 2D, `(i, j, k)` for 3D, etc.).
527 ///
528 /// ```rust
529 /// use scirs2_core::builders::ArrayBuilder;
530 ///
531 /// // 3×3 matrix: element = row + col
532 /// let a = ArrayBuilder::<usize, ndarray::Ix2>::from_fn(
533 /// ndarray::Ix2(3, 3),
534 /// |(r, c)| r + c,
535 /// );
536 /// assert_eq!(a[[2, 2]], 4);
537 /// ```
538 pub fn from_fn<Sh, F>(shape: Sh, f: F) -> ::ndarray::Array<T, D>
539 where
540 Sh: IntoDimension<Dim = D>,
541 F: FnMut(D::Pattern) -> T,
542 {
543 ::ndarray::Array::from_shape_fn(shape, f)
544 }
545
546 /// Build an array from a flat `Vec` of elements in C-order (row-major).
547 ///
548 /// Returns a `CoreError` if the element count does not match the given shape.
549 ///
550 /// ```rust
551 /// use scirs2_core::builders::ArrayBuilder;
552 ///
553 /// let a = ArrayBuilder::<f64, ndarray::Ix2>::from_vec(
554 /// vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
555 /// ndarray::Ix2(2, 3),
556 /// ).expect("element count matches");
557 /// assert_eq!(a[[1, 2]], 6.0);
558 /// ```
559 pub fn from_vec<Sh>(data: Vec<T>, shape: Sh) -> CoreResult<::ndarray::Array<T, D>>
560 where
561 Sh: IntoDimension<Dim = D>,
562 {
563 ::ndarray::Array::from_shape_vec(shape, data).map_err(|e: ShapeError| {
564 CoreError::InvalidInput(ErrorContext::new(format!(
565 "ArrayBuilder::from_vec shape error: {e}"
566 )))
567 })
568 }
569}
570
571impl<T> ArrayBuilder<T, IxDyn>
572where
573 T: Clone + Zero,
574{
575 /// Create a dynamic-dimensional zeros array from a runtime shape slice.
576 ///
577 /// ```rust
578 /// use scirs2_core::builders::ArrayBuilder;
579 ///
580 /// let a = ArrayBuilder::<f64, ndarray::IxDyn>::zeros_dyn(&[2, 3, 4]);
581 /// assert_eq!(a.ndim(), 3);
582 /// assert_eq!(a.shape(), &[2, 3, 4]);
583 /// ```
584 pub fn zeros_dyn(shape: &[usize]) -> ArrayD<T> {
585 ArrayD::zeros(IxDyn(shape))
586 }
587
588 /// Create a dynamic-dimensional array filled with `value`.
589 pub fn full_dyn(shape: &[usize], value: T) -> ArrayD<T> {
590 ArrayD::from_elem(IxDyn(shape), value)
591 }
592}
593
594// ============================================================================
595// Tests
596// ============================================================================
597
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    // --- MatrixBuilder tests ---

    #[test]
    fn test_matrix_zeros() {
        let m = MatrixBuilder::<f64>::zeros(3, 4);
        assert_eq!(m.shape(), &[3, 4]);
        assert!(m.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_matrix_ones() {
        let m = MatrixBuilder::<f64>::ones(2, 5);
        assert_eq!(m.shape(), &[2, 5]);
        assert!(m.iter().all(|&v| v == 1.0));
    }

    #[test]
    fn test_matrix_eye() {
        let eye = MatrixBuilder::<f64>::eye(3);
        assert_eq!(eye.shape(), &[3, 3]);
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_abs_diff_eq!(eye[[i, j]], expected);
            }
        }
    }

    #[test]
    fn test_matrix_eye_empty() {
        let eye = MatrixBuilder::<f64>::eye(0);
        assert_eq!(eye.shape(), &[0, 0]);
    }

    #[test]
    fn test_matrix_from_vec() {
        let m = MatrixBuilder::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0], 2, 2)
            .expect("element count should match");
        assert_eq!(m[[0, 0]], 1.0);
        assert_eq!(m[[0, 1]], 2.0);
        assert_eq!(m[[1, 0]], 3.0);
        assert_eq!(m[[1, 1]], 4.0);
    }

    #[test]
    fn test_matrix_from_vec_error() {
        // Wrong element count → error
        let result = MatrixBuilder::<f64>::from_vec(vec![1.0, 2.0, 3.0], 2, 2);
        assert!(result.is_err());
    }

    #[test]
    fn test_matrix_from_fn() {
        let m = MatrixBuilder::from_fn(3, 3, |r, c| (r * 3 + c) as f64);
        for r in 0..3 {
            for c in 0..3 {
                assert_abs_diff_eq!(m[[r, c]], (r * 3 + c) as f64);
            }
        }
    }

    #[test]
    fn test_matrix_full() {
        let m = MatrixBuilder::full(3, 3, 42_i32);
        assert!(m.iter().all(|&v| v == 42));
    }

    #[test]
    fn test_matrix_rand() {
        let m = MatrixBuilder::<f64>::rand(10, 10, 99);
        assert_eq!(m.shape(), &[10, 10]);
        assert!(m.iter().all(|v| (0.0..1.0).contains(v)));
        // Deterministic: same seed → same values
        let m2 = MatrixBuilder::<f64>::rand(10, 10, 99);
        assert_eq!(m, m2);
    }

    #[test]
    fn test_matrix_randn() {
        let m = MatrixBuilder::<f64>::randn(100, 100, 0);
        // With 10_000 samples the sample mean should be near 0 and the sample
        // standard deviation near 1.
        let mean = m.mean().expect("non-empty");
        assert!(mean.abs() < 0.5, "mean={mean}");
        let std = m.std(0.0);
        assert!((std - 1.0).abs() < 0.1, "std={std}");
    }

    // --- VectorBuilder tests ---

    #[test]
    fn test_vector_zeros() {
        let v = VectorBuilder::<f64>::zeros(5);
        assert_eq!(v.len(), 5);
        assert!(v.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_vector_ones() {
        let v = VectorBuilder::<f64>::ones(4);
        assert_eq!(v.len(), 4);
        assert!(v.iter().all(|&x| x == 1.0));
    }

    #[test]
    fn test_vector_from_vec() {
        let v = VectorBuilder::from_vec(vec![10.0_f64, 20.0, 30.0]);
        assert_eq!(v.len(), 3);
        assert_eq!(v[1], 20.0);
    }

    #[test]
    fn test_vector_from_fn() {
        let v = VectorBuilder::from_fn(5, |i| i as f64 * 2.0);
        assert_abs_diff_eq!(v[3], 6.0);
    }

    #[test]
    fn test_vector_full() {
        // 2.5 rather than 3.14: clippy's deny-by-default `approx_constant`
        // lint rejects literals that approximate PI.
        let v = VectorBuilder::full(4, 2.5_f64);
        assert!(v.iter().all(|&x| (x - 2.5).abs() < 1e-12));
    }

    #[test]
    fn test_vector_linspace() {
        let v = VectorBuilder::<f64>::linspace(0.0, 4.0, 5);
        assert_eq!(v.len(), 5);
        for (i, &val) in v.iter().enumerate() {
            assert_abs_diff_eq!(val, i as f64, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_vector_linspace_single() {
        let v = VectorBuilder::<f64>::linspace(3.0, 3.0, 1);
        assert_eq!(v.len(), 1);
        assert_abs_diff_eq!(v[0], 3.0);
    }

    #[test]
    fn test_vector_linspace_empty() {
        let v = VectorBuilder::<f64>::linspace(0.0, 1.0, 0);
        assert_eq!(v.len(), 0);
    }

    #[test]
    fn test_vector_arange() {
        let v = VectorBuilder::<f64>::arange(0.0, 5.0, 1.0);
        assert_eq!(v.len(), 5);
        for (i, &val) in v.iter().enumerate() {
            assert_abs_diff_eq!(val, i as f64, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_vector_arange_fractional() {
        let v = VectorBuilder::<f64>::arange(0.0, 1.0, 0.5);
        assert_eq!(v.len(), 2);
        assert_abs_diff_eq!(v[0], 0.0, epsilon = 1e-12);
        assert_abs_diff_eq!(v[1], 0.5, epsilon = 1e-12);
    }

    #[test]
    fn test_vector_arange_empty() {
        // step 0 → empty
        let v = VectorBuilder::<f64>::arange(0.0, 5.0, 0.0);
        assert_eq!(v.len(), 0);
        // wrong direction → empty
        let v2 = VectorBuilder::<f64>::arange(5.0, 0.0, 1.0);
        assert_eq!(v2.len(), 0);
    }

    #[test]
    fn test_vector_arange_descending() {
        let v = VectorBuilder::<f64>::arange(3.0, 0.0, -1.0);
        assert_eq!(v.len(), 3);
        assert_abs_diff_eq!(v[0], 3.0, epsilon = 1e-12);
        assert_abs_diff_eq!(v[2], 1.0, epsilon = 1e-12);
    }

    #[test]
    fn test_vector_logspace() {
        let v = VectorBuilder::<f64>::logspace(0.0, 3.0, 4);
        assert_eq!(v.len(), 4);
        assert_abs_diff_eq!(v[0], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(v[1], 10.0, epsilon = 1e-8);
        assert_abs_diff_eq!(v[2], 100.0, epsilon = 1e-6);
        assert_abs_diff_eq!(v[3], 1000.0, epsilon = 1e-4);
    }

    #[test]
    fn test_vector_rand() {
        let v = VectorBuilder::<f64>::rand(20, 7);
        assert_eq!(v.len(), 20);
        assert!(v.iter().all(|x| (0.0..1.0).contains(x)));
        // Determinism
        let v2 = VectorBuilder::<f64>::rand(20, 7);
        assert_eq!(v, v2);
    }

    #[test]
    fn test_vector_randn() {
        let v = VectorBuilder::<f64>::randn(1000, 123);
        assert_eq!(v.len(), 1000);
        let mean = v.mean().expect("non-empty");
        assert!(mean.abs() < 0.2, "mean={mean}");
    }

    // --- ArrayBuilder tests ---

    #[test]
    fn test_array_builder_zeros_2d() {
        let a = ArrayBuilder::<f64, ::ndarray::Ix2>::zeros(::ndarray::Ix2(3, 4));
        assert_eq!(a.shape(), &[3, 4]);
        assert!(a.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_array_builder_zeros_3d() {
        let a = ArrayBuilder::<f64, ::ndarray::Ix3>::zeros(::ndarray::Ix3(2, 3, 4));
        assert_eq!(a.shape(), &[2, 3, 4]);
    }

    #[test]
    fn test_array_builder_zeros_dyn() {
        let a = ArrayBuilder::<f64, ::ndarray::IxDyn>::zeros_dyn(&[2, 3, 4]);
        assert_eq!(a.ndim(), 3);
        assert_eq!(a.shape(), &[2, 3, 4]);
    }

    #[test]
    fn test_array_builder_full_dyn() {
        let a = ArrayBuilder::<i32, ::ndarray::IxDyn>::full_dyn(&[2, 2], -1);
        assert_eq!(a.shape(), &[2, 2]);
        assert!(a.iter().all(|&v| v == -1));
    }

    #[test]
    fn test_array_builder_full() {
        let a = ArrayBuilder::<i32, ::ndarray::Ix2>::full(::ndarray::Ix2(3, 3), 7);
        assert!(a.iter().all(|&v| v == 7));
    }

    #[test]
    fn test_array_builder_from_fn_3d() {
        let a = ArrayBuilder::<usize, ::ndarray::Ix3>::from_fn(
            ::ndarray::Ix3(2, 3, 4),
            |(i, j, k)| i * 100 + j * 10 + k,
        );
        assert_eq!(a.shape(), &[2, 3, 4]);
        assert_eq!(a[[1, 2, 3]], 123);
    }

    #[test]
    fn test_array_builder_from_vec_ok() {
        let a = ArrayBuilder::<f64, ::ndarray::Ix2>::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            ::ndarray::Ix2(2, 3),
        )
        .expect("valid shape");
        assert_eq!(a[[1, 2]], 6.0);
    }

    #[test]
    fn test_array_builder_from_vec_err() {
        let result = ArrayBuilder::<f64, ::ndarray::Ix2>::from_vec(
            vec![1.0, 2.0, 3.0],
            ::ndarray::Ix2(2, 3), // needs 6 elements
        );
        assert!(result.is_err());
    }
}