Skip to main content

sklears_simd/
fluent.rs

1//! Fluent API for SIMD operations
2//!
3//! This module provides a chainable, fluent interface for composing complex
4//! SIMD operations in a readable and efficient manner.
5
6#[cfg(feature = "no-std")]
7use alloc::format;
8#[cfg(feature = "no-std")]
9use alloc::string::String;
10#[cfg(feature = "no-std")]
11use alloc::vec;
12#[cfg(feature = "no-std")]
13use alloc::vec::Vec;
14
15// Note: mem import removed as it's no longer needed after proper transpose implementation
16
17use crate::activation;
18use crate::allocator::SimdVec;
19use crate::distance;
20use crate::kernels;
21use crate::loss;
22use crate::safety::SafeSimdOps;
23use crate::vector;
24
/// Fluent, chainable builder for `f32` vector operations.
///
/// Every transforming method consumes the builder and hands it back, so calls
/// can be chained and finished with `build`/`build_simd`.
#[derive(Debug, Clone)]
pub struct VectorBuilder {
    /// Elements being accumulated and transformed.
    data: Vec<f32>,
    /// When `true`, arithmetic is routed through the checked `SafeSimdOps` helpers.
    safe_mode: bool,
}
31
32impl VectorBuilder {
33    /// Create a new vector builder
34    pub fn new() -> Self {
35        Self {
36            data: Vec::new(),
37            safe_mode: false,
38        }
39    }
40
41    /// Create a vector builder from existing data
42    pub fn from_slice(data: &[f32]) -> Self {
43        Self {
44            data: data.to_vec(),
45            safe_mode: false,
46        }
47    }
48
49    /// Create a vector builder with SIMD-aligned storage
50    pub fn with_simd_storage(capacity: usize) -> Self {
51        let simd_vec = SimdVec::with_capacity(capacity);
52        Self {
53            data: simd_vec.as_slice().to_vec(),
54            safe_mode: false,
55        }
56    }
57
58    /// Enable safe mode with bounds checking and overflow detection
59    pub fn safe(mut self) -> Self {
60        self.safe_mode = true;
61        self
62    }
63
64    /// Add elements to the vector
65    pub fn push(mut self, value: f32) -> Self {
66        self.data.push(value);
67        self
68    }
69
70    /// Add multiple elements to the vector
71    pub fn extend(mut self, values: &[f32]) -> Self {
72        self.data.extend_from_slice(values);
73        self
74    }
75
76    /// Fill the vector with a value
77    pub fn fill(mut self, size: usize, value: f32) -> Self {
78        self.data = vec![value; size];
79        self
80    }
81
82    /// Create a range of values
83    pub fn range(mut self, start: f32, end: f32, step: f32) -> Self {
84        let mut current = start;
85        self.data.clear();
86        while current < end {
87            self.data.push(current);
88            current += step;
89        }
90        self
91    }
92
93    /// Create a linearly spaced vector
94    pub fn linspace(mut self, start: f32, end: f32, num: usize) -> Self {
95        if num == 0 {
96            self.data.clear();
97            return self;
98        }
99
100        if num == 1 {
101            self.data = vec![start];
102            return self;
103        }
104
105        let step = (end - start) / (num - 1) as f32;
106        self.data = (0..num).map(|i| start + (i as f32) * step).collect();
107        self
108    }
109
110    /// Scale all elements by a factor
111    pub fn scale(mut self, factor: f32) -> Self {
112        if self.safe_mode {
113            for value in &mut self.data {
114                *value = SafeSimdOps::safe_mul_f32(*value, factor).unwrap_or(0.0);
115            }
116        } else {
117            vector::scale(&mut self.data, factor);
118        }
119        self
120    }
121
122    /// Add a scalar to all elements
123    pub fn add_scalar(mut self, value: f32) -> Self {
124        if self.safe_mode {
125            for element in &mut self.data {
126                *element = SafeSimdOps::safe_add_f32(*element, value).unwrap_or(0.0);
127            }
128        } else {
129            for element in &mut self.data {
130                *element += value;
131            }
132        }
133        self
134    }
135
136    /// Apply element-wise operation
137    pub fn map<F>(mut self, f: F) -> Self
138    where
139        F: Fn(f32) -> f32,
140    {
141        for element in &mut self.data {
142            *element = f(*element);
143        }
144        self
145    }
146
147    /// Normalize the vector
148    pub fn normalize(mut self) -> Self {
149        if self.safe_mode {
150            self.data = SafeSimdOps::safe_normalize_f32(&self.data).unwrap_or_default();
151        } else {
152            let norm = vector::norm_l2(&self.data);
153            if norm > 0.0 {
154                vector::scale(&mut self.data, 1.0 / norm);
155            }
156        }
157        self
158    }
159
160    /// Calculate dot product with another vector
161    pub fn dot(&self, other: &[f32]) -> f32 {
162        if self.safe_mode {
163            SafeSimdOps::safe_dot_product_f32(&self.data, other).unwrap_or(0.0)
164        } else {
165            vector::dot_product(&self.data, other)
166        }
167    }
168
169    /// Calculate distance to another vector
170    pub fn distance_to(&self, other: &[f32], metric: DistanceMetric) -> f32 {
171        match metric {
172            DistanceMetric::Euclidean => distance::euclidean_distance(&self.data, other),
173            DistanceMetric::Manhattan => distance::manhattan_distance(&self.data, other),
174            DistanceMetric::Cosine => distance::cosine_distance(&self.data, other),
175            DistanceMetric::Chebyshev => distance::chebyshev_distance(&self.data, other),
176        }
177    }
178
179    /// Apply activation function
180    pub fn activate(mut self, activation: ActivationFunction) -> Self {
181        let mut output = vec![0.0; self.data.len()];
182        match activation {
183            ActivationFunction::Sigmoid => activation::sigmoid(&self.data, &mut output),
184            ActivationFunction::Relu => activation::relu(&self.data, &mut output),
185            ActivationFunction::Tanh => activation::tanh_activation(&self.data, &mut output),
186            ActivationFunction::Softmax => activation::softmax(&self.data, &mut output),
187        }
188        self.data = output;
189        self
190    }
191
192    /// Get statistics about the vector
193    pub fn stats(&self) -> VectorStats {
194        let (min, max) = vector::min_max(&self.data);
195        VectorStats {
196            mean: vector::mean(&self.data),
197            variance: vector::variance(&self.data),
198            min,
199            max,
200            norm: vector::norm_l2(&self.data),
201            length: self.data.len(),
202        }
203    }
204
205    /// Build the final vector
206    pub fn build(self) -> Vec<f32> {
207        self.data
208    }
209
210    /// Build into a SIMD-aligned vector
211    pub fn build_simd(self) -> SimdVec<f32> {
212        let mut simd_vec = SimdVec::new();
213        for value in self.data {
214            simd_vec.push(value);
215        }
216        simd_vec
217    }
218
219    /// Get a reference to the underlying data
220    pub fn as_slice(&self) -> &[f32] {
221        &self.data
222    }
223}
224
225impl Default for VectorBuilder {
226    fn default() -> Self {
227        Self::new()
228    }
229}
230
/// Distance metrics accepted by `VectorBuilder::distance_to`.
///
/// Equality and hashing are derived so metrics can be compared and used as
/// map/set keys (e.g. for caching per-metric results).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DistanceMetric {
    /// Dispatches to `distance::euclidean_distance`.
    Euclidean,
    /// Dispatches to `distance::manhattan_distance`.
    Manhattan,
    /// Dispatches to `distance::cosine_distance`.
    Cosine,
    /// Dispatches to `distance::chebyshev_distance`.
    Chebyshev,
}
239
/// Activation functions accepted by `VectorBuilder::activate`.
///
/// Equality and hashing are derived so activations can be compared and used
/// as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ActivationFunction {
    /// Dispatches to `activation::sigmoid`.
    Sigmoid,
    /// Dispatches to `activation::relu`.
    Relu,
    /// Dispatches to `activation::tanh_activation`.
    Tanh,
    /// Dispatches to `activation::softmax`.
    Softmax,
}
248
/// Summary statistics of a vector, as produced by `VectorBuilder::stats`.
///
/// `PartialEq` is derived so results can be compared directly (note that, as
/// with any `f32` comparison, a NaN field makes two stats unequal).
#[derive(Debug, Clone, PartialEq)]
pub struct VectorStats {
    /// Mean of the elements (from `vector::mean`).
    pub mean: f32,
    /// Variance of the elements (from `vector::variance`; NOTE(review): confirm
    /// population vs. sample convention there).
    pub variance: f32,
    /// Smallest element (from `vector::min_max`).
    pub min: f32,
    /// Largest element (from `vector::min_max`).
    pub max: f32,
    /// L2 norm (from `vector::norm_l2`).
    pub norm: f32,
    /// Number of elements.
    pub length: usize,
}
259
/// Fluent builder for dense `f32` matrices stored in row-major order.
#[derive(Debug, Clone)]
pub struct MatrixBuilder {
    /// Elements in row-major order (`row * cols + col`); `rows * cols` long.
    data: Vec<f32>,
    /// Number of rows.
    rows: usize,
    /// Number of columns.
    cols: usize,
    /// When `true`, products are computed through the checked `SafeSimdOps` helpers.
    safe_mode: bool,
}
268
269impl MatrixBuilder {
270    /// Create a new matrix builder
271    pub fn new(rows: usize, cols: usize) -> Self {
272        Self {
273            data: vec![0.0; rows * cols],
274            rows,
275            cols,
276            safe_mode: false,
277        }
278    }
279
280    /// Create from existing data
281    pub fn from_data(data: Vec<f32>, rows: usize, cols: usize) -> Result<Self, String> {
282        if data.len() != rows * cols {
283            return Err(format!(
284                "Data length {} doesn't match matrix dimensions {}x{}",
285                data.len(),
286                rows,
287                cols
288            ));
289        }
290
291        Ok(Self {
292            data,
293            rows,
294            cols,
295            safe_mode: false,
296        })
297    }
298
299    /// Create an identity matrix
300    pub fn identity(size: usize) -> Self {
301        let mut data = vec![0.0; size * size];
302        for i in 0..size {
303            data[i * size + i] = 1.0;
304        }
305
306        Self {
307            data,
308            rows: size,
309            cols: size,
310            safe_mode: false,
311        }
312    }
313
314    /// Create a matrix filled with random values
315    pub fn random(rows: usize, cols: usize, min: f32, max: f32) -> Self {
316        use scirs2_core::random::thread_rng;
317        let mut rng = thread_rng();
318        let data: Vec<f32> = (0..rows * cols)
319            .map(|_| {
320                let val: f32 = rng.random::<f32>();
321                min + val * (max - min)
322            })
323            .collect();
324
325        Self {
326            data,
327            rows,
328            cols,
329            safe_mode: false,
330        }
331    }
332
333    /// Enable safe mode
334    pub fn safe(mut self) -> Self {
335        self.safe_mode = true;
336        self
337    }
338
339    /// Set a value at position (row, col)
340    pub fn set(mut self, row: usize, col: usize, value: f32) -> Result<Self, String> {
341        if row >= self.rows || col >= self.cols {
342            return Err(format!(
343                "Index ({}, {}) out of bounds for {}x{} matrix",
344                row, col, self.rows, self.cols
345            ));
346        }
347
348        self.data[row * self.cols + col] = value;
349        Ok(self)
350    }
351
352    /// Fill the matrix with a value
353    pub fn fill(mut self, value: f32) -> Self {
354        self.data.fill(value);
355        self
356    }
357
358    /// Scale all elements
359    pub fn scale(mut self, factor: f32) -> Self {
360        if self.safe_mode {
361            for value in &mut self.data {
362                *value = SafeSimdOps::safe_mul_f32(*value, factor).unwrap_or(0.0);
363            }
364        } else {
365            vector::scale(&mut self.data, factor);
366        }
367        self
368    }
369
370    /// Transpose the matrix
371    ///
372    /// Converts an MxN matrix to an NxM matrix by swapping rows and columns.
373    /// Uses row-major to row-major transposition for cache efficiency.
374    pub fn transpose(self) -> Self {
375        let mut transposed_data = vec![0.0; self.data.len()];
376
377        // Transpose: transposed[col][row] = original[row][col]
378        for row in 0..self.rows {
379            for col in 0..self.cols {
380                let src_idx = row * self.cols + col;
381                let dst_idx = col * self.rows + row;
382                transposed_data[dst_idx] = self.data[src_idx];
383            }
384        }
385
386        Self {
387            data: transposed_data,
388            rows: self.cols,
389            cols: self.rows,
390            safe_mode: self.safe_mode,
391        }
392    }
393
394    /// Multiply by another matrix
395    ///
396    /// Computes C = A × B where A is self (MxN), B is other (NxP), and C is (MxP).
397    /// Uses standard O(M×N×P) algorithm with row-major order optimization.
398    pub fn multiply(&self, other: &MatrixBuilder) -> Result<MatrixBuilder, String> {
399        if self.cols != other.rows {
400            return Err(format!(
401                "Cannot multiply {}x{} matrix by {}x{} matrix",
402                self.rows, self.cols, other.rows, other.cols
403            ));
404        }
405
406        let m = self.rows;
407        let n = self.cols;
408        let p = other.cols;
409        let mut result_data = vec![0.0; m * p];
410
411        // C[i][j] = sum_k A[i][k] * B[k][j]
412        for i in 0..m {
413            for j in 0..p {
414                let mut sum = 0.0;
415                for k in 0..n {
416                    let a_val = self.data[i * self.cols + k];
417                    let b_val = other.data[k * other.cols + j];
418
419                    if self.safe_mode || other.safe_mode {
420                        sum += SafeSimdOps::safe_mul_f32(a_val, b_val).unwrap_or(0.0);
421                    } else {
422                        sum += a_val * b_val;
423                    }
424                }
425                result_data[i * p + j] = sum;
426            }
427        }
428
429        Ok(MatrixBuilder {
430            data: result_data,
431            rows: m,
432            cols: p,
433            safe_mode: self.safe_mode || other.safe_mode,
434        })
435    }
436
437    /// Multiply by a vector
438    ///
439    /// Computes y = A × x where A is self (MxN), x is vector (N), and y is (M).
440    /// Uses row-major order for cache-efficient access.
441    pub fn multiply_vector(&self, vector: &[f32]) -> Result<Vec<f32>, String> {
442        if self.cols != vector.len() {
443            return Err(format!(
444                "Cannot multiply {}x{} matrix by vector of length {}",
445                self.rows,
446                self.cols,
447                vector.len()
448            ));
449        }
450
451        let mut result = vec![0.0; self.rows];
452
453        // y[i] = sum_j A[i][j] * x[j]
454        for (i, res) in result.iter_mut().enumerate() {
455            let row_start = i * self.cols;
456            let mut sum = 0.0;
457            for (j, &x_val) in vector.iter().enumerate().take(self.cols) {
458                let a_val = self.data[row_start + j];
459
460                if self.safe_mode {
461                    sum += SafeSimdOps::safe_mul_f32(a_val, x_val).unwrap_or(0.0);
462                } else {
463                    sum += a_val * x_val;
464                }
465            }
466            *res = sum;
467        }
468
469        Ok(result)
470    }
471
472    /// Get matrix dimensions
473    pub fn dimensions(&self) -> (usize, usize) {
474        (self.rows, self.cols)
475    }
476
477    /// Build the final matrix
478    pub fn build(self) -> (Vec<f32>, usize, usize) {
479        (self.data, self.rows, self.cols)
480    }
481
482    /// Get a reference to the underlying data
483    pub fn as_slice(&self) -> &[f32] {
484        &self.data
485    }
486}
487
/// Fluent builder for machine learning operations.
///
/// `features` serves as the prediction vector for the loss/gradient methods and
/// as the first operand for kernel evaluation. `Clone` is derived for
/// consistency with `VectorBuilder` and `MatrixBuilder`.
#[derive(Debug, Clone)]
pub struct MLBuilder {
    /// Predictions / feature values.
    features: Vec<f32>,
    /// Target values compared against `features` by the loss functions.
    targets: Vec<f32>,
    /// Safe-mode flag (see `MLBuilder::safe`).
    safe_mode: bool,
}
495
496impl MLBuilder {
497    /// Create a new ML builder
498    pub fn new() -> Self {
499        Self {
500            features: Vec::new(),
501            targets: Vec::new(),
502            safe_mode: false,
503        }
504    }
505
506    /// Set feature data
507    pub fn features(mut self, features: Vec<f32>) -> Self {
508        self.features = features;
509        self
510    }
511
512    /// Set target data
513    pub fn targets(mut self, targets: Vec<f32>) -> Self {
514        self.targets = targets;
515        self
516    }
517
518    /// Enable safe mode
519    pub fn safe(mut self) -> Self {
520        self.safe_mode = true;
521        self
522    }
523
524    /// Calculate loss using specified function
525    pub fn loss(&self, loss_type: LossFunction) -> f32 {
526        match loss_type {
527            LossFunction::MSE => loss::mse_loss(&self.features, &self.targets),
528            LossFunction::MAE => loss::mae_loss(&self.features, &self.targets),
529            LossFunction::Huber(delta) => loss::huber_loss(&self.features, &self.targets, delta),
530        }
531    }
532
533    /// Calculate gradients
534    pub fn gradients(&self, loss_type: LossFunction) -> Vec<f32> {
535        let mut output = vec![0.0; self.features.len()];
536        match loss_type {
537            LossFunction::MSE => loss::mse_gradient(&self.features, &self.targets, &mut output),
538            LossFunction::MAE => loss::mae_gradient(&self.features, &self.targets, &mut output),
539            LossFunction::Huber(delta) => {
540                loss::huber_gradient(&self.features, &self.targets, delta, &mut output)
541            }
542        }
543        output
544    }
545
546    /// Compute kernel values
547    pub fn kernel(&self, other: &[f32], kernel_type: KernelType) -> f32 {
548        match kernel_type {
549            KernelType::Linear => kernels::linear_kernel(&self.features, other),
550            KernelType::RBF(gamma) => kernels::rbf_kernel(&self.features, other, gamma),
551            KernelType::Polynomial(degree, coef) => {
552                kernels::polynomial_kernel(&self.features, other, degree, coef, 1.0)
553            }
554        }
555    }
556}
557
558impl Default for MLBuilder {
559    fn default() -> Self {
560        Self::new()
561    }
562}
563
/// Loss functions accepted by `MLBuilder::loss` and `MLBuilder::gradients`.
///
/// `PartialEq` is derived so configurations can be compared; `Eq`/`Hash` are
/// not, because `Huber` carries an `f32`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LossFunction {
    /// Dispatches to `loss::mse_loss` / `loss::mse_gradient`.
    MSE,
    /// Dispatches to `loss::mae_loss` / `loss::mae_gradient`.
    MAE,
    /// Dispatches to `loss::huber_loss` / `loss::huber_gradient` with the given `delta`.
    Huber(f32),
}
571
/// Kernel functions accepted by `MLBuilder::kernel`.
///
/// `PartialEq` is derived so configurations can be compared; `Eq`/`Hash` are
/// not, because the parameters are `f32`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum KernelType {
    /// Dispatches to `kernels::linear_kernel`.
    Linear,
    /// Dispatches to `kernels::rbf_kernel` with the given `gamma`.
    RBF(f32),
    /// Dispatches to `kernels::polynomial_kernel` with `(degree, coef)`.
    Polynomial(f32, f32),
}
579
580/// Convenience functions for common operations
581pub mod ops {
582    use super::*;
583
584    /// Create a vector with fluent API
585    pub fn vector() -> VectorBuilder {
586        VectorBuilder::new()
587    }
588
589    /// Create a matrix with fluent API
590    pub fn matrix(rows: usize, cols: usize) -> MatrixBuilder {
591        MatrixBuilder::new(rows, cols)
592    }
593
594    /// Create an ML builder
595    pub fn ml() -> MLBuilder {
596        MLBuilder::new()
597    }
598
599    /// Quick vector operations
600    pub fn quick_dot(a: &[f32], b: &[f32]) -> f32 {
601        VectorBuilder::from_slice(a).dot(b)
602    }
603
604    /// Quick distance calculation
605    pub fn quick_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
606        VectorBuilder::from_slice(a).distance_to(b, metric)
607    }
608
609    /// Quick normalization
610    pub fn quick_normalize(data: &[f32]) -> Vec<f32> {
611        VectorBuilder::from_slice(data).normalize().build()
612    }
613}
614
// Unit tests for the fluent builders. The module only exists when `no-std` is
// disabled, so the std prelude's `vec!`/`Vec` are always available here.
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::ops::*;
    use super::*;

    #[test]
    fn test_vector_builder_basic() {
        let vec = vector().push(1.0).push(2.0).push(3.0).scale(2.0).build();

        assert_eq!(vec, [2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_vector_builder_chaining() {
        let vec = vector()
            .linspace(0.0, 10.0, 11)
            .scale(0.1)
            .add_scalar(1.0)
            .normalize()
            .build();

        assert_eq!(vec.len(), 11);
        let norm = VectorBuilder::from_slice(&vec).stats().norm;
        assert!((norm - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_vector_builder_stats() {
        let stats = vector().range(1.0, 6.0, 1.0).stats();

        assert_eq!(stats.length, 5);
        assert_eq!(stats.mean, 3.0);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
    }

    #[test]
    fn test_vector_builder_distance() {
        let vec1 = vector().range(0.0, 3.0, 1.0).build();
        let vec2 = vector().range(1.0, 4.0, 1.0).build();

        let distance =
            VectorBuilder::from_slice(&vec1).distance_to(&vec2, DistanceMetric::Euclidean);

        assert!((distance - (3.0_f32).sqrt()).abs() < 1e-6);
    }

    #[test]
    fn test_range_fractional_step() {
        // Exactly representable step: the exclusive end must not be included.
        let vec = vector().range(0.0, 1.0, 0.25).build();
        assert_eq!(vec, [0.0, 0.25, 0.5, 0.75]);
    }

    #[test]
    fn test_linspace_endpoints() {
        assert_eq!(
            vector().linspace(0.0, 1.0, 5).build(),
            [0.0, 0.25, 0.5, 0.75, 1.0]
        );
        assert_eq!(vector().linspace(2.0, 5.0, 1).build(), [2.0]);
        assert!(vector().linspace(0.0, 1.0, 0).build().is_empty());
    }

    #[test]
    fn test_matrix_builder_basic() {
        let matrix = matrix(2, 2).fill(1.0).scale(2.0).build();

        assert_eq!(matrix.0, [2.0, 2.0, 2.0, 2.0]);
        assert_eq!(matrix.1, 2);
        assert_eq!(matrix.2, 2);
    }

    #[test]
    fn test_matrix_builder_identity() {
        let identity = MatrixBuilder::identity(3).build();
        let expected = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];

        assert_eq!(identity.0, expected);
    }

    #[test]
    fn test_matrix_multiplication() {
        // Test basic 2x2 matrix multiplication
        // A = [[1, 2], [3, 4]]
        // B = [[2, 0], [1, 2]]
        // A × B = [[1*2+2*1, 1*0+2*2], [3*2+4*1, 3*0+4*2]]
        //       = [[4, 4], [10, 8]]
        let a = MatrixBuilder::from_data(vec![1.0, 2.0, 3.0, 4.0], 2, 2)
            .expect("operation should succeed");
        let b = MatrixBuilder::from_data(vec![2.0, 0.0, 1.0, 2.0], 2, 2)
            .expect("operation should succeed");

        let result = a.multiply(&b).expect("operation should succeed").build();
        assert_eq!(result.0, [4.0, 4.0, 10.0, 8.0]);
        assert_eq!(result.1, 2); // rows
        assert_eq!(result.2, 2); // cols
    }

    #[test]
    fn test_matrix_multiplication_rectangular() {
        // Test non-square matrix multiplication
        // A = [[1, 2, 3]] (1x3)
        // B = [[4], [5], [6]] (3x1)
        // A × B = [[1*4 + 2*5 + 3*6]] = [[32]] (1x1)
        let a =
            MatrixBuilder::from_data(vec![1.0, 2.0, 3.0], 1, 3).expect("operation should succeed");
        let b =
            MatrixBuilder::from_data(vec![4.0, 5.0, 6.0], 3, 1).expect("operation should succeed");

        let result = a.multiply(&b).expect("operation should succeed").build();
        assert_eq!(result.0, [32.0]);
        assert_eq!(result.1, 1);
        assert_eq!(result.2, 1);
    }

    #[test]
    fn test_matrix_multiplication_identity() {
        // Test multiplication by identity matrix
        let a = MatrixBuilder::from_data(vec![1.0, 2.0, 3.0, 4.0], 2, 2)
            .expect("operation should succeed");
        let identity = MatrixBuilder::identity(2);

        let result = a
            .multiply(&identity)
            .expect("operation should succeed")
            .build();
        assert_eq!(result.0, [1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_matrix_multiplication_zero_inner_dimension() {
        // A (2x0) × B (0x3) is the 2x3 zero matrix.
        let a = MatrixBuilder::from_data(vec![], 2, 0).expect("operation should succeed");
        let b = MatrixBuilder::from_data(vec![], 0, 3).expect("operation should succeed");

        let result = a.multiply(&b).expect("operation should succeed").build();
        assert_eq!(result.0, [0.0; 6]);
        assert_eq!(result.1, 2);
        assert_eq!(result.2, 3);
    }

    #[test]
    fn test_matrix_transpose() {
        // Test basic transpose
        // A = [[1, 2, 3], [4, 5, 6]] (2x3)
        // A^T = [[1, 4], [2, 5], [3, 6]] (3x2)
        let a = MatrixBuilder::from_data(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3)
            .expect("operation should succeed");

        let transposed = a.transpose().build();
        assert_eq!(transposed.0, [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
        assert_eq!(transposed.1, 3); // rows
        assert_eq!(transposed.2, 2); // cols
    }

    #[test]
    fn test_matrix_transpose_square() {
        // Test square matrix transpose
        // A = [[1, 2], [3, 4]] (2x2)
        // A^T = [[1, 3], [2, 4]] (2x2)
        let a = MatrixBuilder::from_data(vec![1.0, 2.0, 3.0, 4.0], 2, 2)
            .expect("operation should succeed");

        let transposed = a.transpose().build();
        assert_eq!(transposed.0, [1.0, 3.0, 2.0, 4.0]);
        assert_eq!(transposed.1, 2);
        assert_eq!(transposed.2, 2);
    }

    #[test]
    fn test_matrix_transpose_double() {
        // Test that (A^T)^T = A
        let a = MatrixBuilder::from_data(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3)
            .expect("operation should succeed");
        let original_data = a.data.clone();

        let double_transposed = a.transpose().transpose().build();
        assert_eq!(double_transposed.0, original_data);
        assert_eq!(double_transposed.1, 2);
        assert_eq!(double_transposed.2, 3);
    }

    #[test]
    fn test_matrix_vector_multiplication() {
        // Test basic matrix-vector multiplication
        // A = [[1, 2], [3, 4]] (2x2)
        // x = [5, 6]
        // A × x = [1*5+2*6, 3*5+4*6] = [17, 39]
        let a = MatrixBuilder::from_data(vec![1.0, 2.0, 3.0, 4.0], 2, 2)
            .expect("operation should succeed");
        let x = vec![5.0, 6.0];

        let result = a.multiply_vector(&x).expect("operation should succeed");
        assert_eq!(result, [17.0, 39.0]);
    }

    #[test]
    fn test_matrix_vector_multiplication_rectangular() {
        // Test rectangular matrix-vector multiplication
        // A = [[1, 2, 3], [4, 5, 6]] (2x3)
        // x = [1, 2, 3]
        // A × x = [1*1+2*2+3*3, 4*1+5*2+6*3] = [14, 32]
        let a = MatrixBuilder::from_data(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3)
            .expect("operation should succeed");
        let x = vec![1.0, 2.0, 3.0];

        let result = a.multiply_vector(&x).expect("operation should succeed");
        assert_eq!(result, [14.0, 32.0]);
    }

    #[test]
    fn test_matrix_vector_multiplication_identity() {
        // Test that identity matrix times vector equals vector
        let identity = MatrixBuilder::identity(3);
        let x = vec![5.0, 6.0, 7.0];

        let result = identity
            .multiply_vector(&x)
            .expect("operation should succeed");
        assert_eq!(result, x);
    }

    #[test]
    fn test_matrix_vector_multiplication_zero_columns() {
        // A 3x0 matrix maps the empty vector to three zeros.
        let a = MatrixBuilder::from_data(vec![], 3, 0).expect("operation should succeed");
        let empty: [f32; 0] = [];

        let result = a.multiply_vector(&empty).expect("operation should succeed");
        assert_eq!(result, [0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_matrix_operations_dimension_checks() {
        // Test dimension validation
        let a = MatrixBuilder::from_data(vec![1.0, 2.0, 3.0, 4.0], 2, 2)
            .expect("operation should succeed");
        let b =
            MatrixBuilder::from_data(vec![1.0, 2.0, 3.0], 1, 3).expect("operation should succeed");

        // Should fail: 2x2 × 1x3 (incompatible dimensions)
        assert!(a.multiply(&b).is_err());

        // Should fail: 2x2 matrix with 3-element vector
        assert!(a.multiply_vector(&[1.0, 2.0, 3.0]).is_err());
    }

    #[test]
    fn test_ml_builder() {
        let predictions = vec![1.0, 2.0, 3.0];
        let targets = vec![1.5, 1.8, 2.7];

        let mse = ml()
            .features(predictions.clone())
            .targets(targets.clone())
            .loss(LossFunction::MSE);

        assert!(mse > 0.0);

        let gradients = ml()
            .features(predictions)
            .targets(targets)
            .gradients(LossFunction::MSE);

        assert_eq!(gradients.len(), 3);
    }

    #[test]
    fn test_quick_operations() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];

        let dot = quick_dot(&a, &b);
        assert_eq!(dot, 32.0);

        let distance = quick_distance(&a, &b, DistanceMetric::Euclidean);
        assert!((distance - (27.0_f32).sqrt()).abs() < 1e-6);

        let normalized = quick_normalize(&a);
        let norm = VectorBuilder::from_slice(&normalized).stats().norm;
        assert!((norm - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_safe_mode() {
        let vec = vector()
            .safe()
            .push(1.0)
            .push(2.0)
            .scale(f32::MAX) // This would overflow without safe mode
            .build();

        // In safe mode, overflow should be handled gracefully
        assert!(vec.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_map_operation() {
        let vec = vector()
            .range(1.0, 4.0, 1.0)
            .map(|x| x * x) // Square each element
            .build();

        assert_eq!(vec, [1.0, 4.0, 9.0]);
    }
}
882            .build();
883
884        assert_eq!(vec, [1.0, 4.0, 9.0]);
885    }
886}