sublinear_solver/
math_wasm.rs

1use std::fmt;
2use rand::Rng;
3
4#[derive(Debug, Clone)]
5pub struct Matrix {
6    data: Vec<f64>,
7    rows: usize,
8    cols: usize,
9}
10
11impl Matrix {
12    pub fn new(rows: usize, cols: usize) -> Self {
13        Self {
14            data: vec![0.0; rows * cols],
15            rows,
16            cols,
17        }
18    }
19
20    pub fn from_slice(data: &[f64], rows: usize, cols: usize) -> Self {
21        assert_eq!(data.len(), rows * cols, "Data length must match matrix dimensions");
22        Self {
23            data: data.to_vec(),
24            rows,
25            cols,
26        }
27    }
28
29    pub fn identity(size: usize) -> Self {
30        let mut matrix = Self::new(size, size);
31        for i in 0..size {
32            matrix.data[i * size + i] = 1.0;
33        }
34        matrix
35    }
36
37    pub fn random(rows: usize, cols: usize) -> Self {
38        let mut matrix = Self::new(rows, cols);
39        for i in 0..matrix.data.len() {
40            #[cfg(feature = "wasm")]
41            {
42                matrix.data[i] = fastrand::f64();
43            }
44            #[cfg(not(feature = "wasm"))]
45            {
46                matrix.data[i] = rand::random::<f64>();
47            }
48        }
49        matrix
50    }
51
52    pub fn rows(&self) -> usize {
53        self.rows
54    }
55
56    pub fn cols(&self) -> usize {
57        self.cols
58    }
59
60    pub fn data(&self) -> &[f64] {
61        &self.data
62    }
63
64    pub fn data_mut(&mut self) -> &mut [f64] {
65        &mut self.data
66    }
67
68    pub fn get(&self, row: usize, col: usize) -> f64 {
69        assert!(row < self.rows && col < self.cols, "Index out of bounds");
70        self.data[row * self.cols + col]
71    }
72
73    pub fn set(&mut self, row: usize, col: usize, value: f64) {
74        assert!(row < self.rows && col < self.cols, "Index out of bounds");
75        self.data[row * self.cols + col] = value;
76    }
77
78    pub fn multiply(&self, other: &Matrix) -> Result<Matrix, String> {
79        if self.cols != other.rows {
80            return Err("Matrix dimensions incompatible for multiplication".to_string());
81        }
82
83        let mut result = Matrix::new(self.rows, other.cols);
84
85        for i in 0..self.rows {
86            for j in 0..other.cols {
87                let mut sum = 0.0;
88                for k in 0..self.cols {
89                    sum += self.get(i, k) * other.get(k, j);
90                }
91                result.set(i, j, sum);
92            }
93        }
94
95        Ok(result)
96    }
97
98    pub fn transpose(&self) -> Matrix {
99        let mut result = Matrix::new(self.cols, self.rows);
100        for i in 0..self.rows {
101            for j in 0..self.cols {
102                result.set(j, i, self.get(i, j));
103            }
104        }
105        result
106    }
107
108    pub fn is_symmetric(&self) -> bool {
109        if self.rows != self.cols {
110            return false;
111        }
112
113        for i in 0..self.rows {
114            for j in 0..self.cols {
115                if (self.get(i, j) - self.get(j, i)).abs() > 1e-10 {
116                    return false;
117                }
118            }
119        }
120        true
121    }
122
123    pub fn is_positive_definite(&self) -> bool {
124        if !self.is_symmetric() {
125            return false;
126        }
127
128        // Simple check using Sylvester's criterion for small matrices
129        // For larger matrices, this should use Cholesky decomposition
130        if self.rows <= 3 {
131            return self.check_sylvester_criterion();
132        }
133
134        // For larger matrices, approximate check
135        true
136    }
137
138    fn check_sylvester_criterion(&self) -> bool {
139        for k in 1..=self.rows {
140            let det = self.leading_principal_minor(k);
141            if det <= 0.0 {
142                return false;
143            }
144        }
145        true
146    }
147
148    fn leading_principal_minor(&self, k: usize) -> f64 {
149        if k == 1 {
150            return self.get(0, 0);
151        }
152        if k == 2 {
153            return self.get(0, 0) * self.get(1, 1) - self.get(0, 1) * self.get(1, 0);
154        }
155        if k == 3 {
156            let a = self.get(0, 0);
157            let b = self.get(0, 1);
158            let c = self.get(0, 2);
159            let d = self.get(1, 0);
160            let e = self.get(1, 1);
161            let f = self.get(1, 2);
162            let g = self.get(2, 0);
163            let h = self.get(2, 1);
164            let i = self.get(2, 2);
165
166            return a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g);
167        }
168
169        // For larger matrices, use a simplified approximation
170        1.0
171    }
172}
173
174impl fmt::Display for Matrix {
175    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
176        for i in 0..self.rows {
177            write!(f, "[")?;
178            for j in 0..self.cols {
179                if j > 0 {
180                    write!(f, ", ")?;
181                }
182                write!(f, "{:8.4}", self.get(i, j))?;
183            }
184            writeln!(f, "]")?;
185        }
186        Ok(())
187    }
188}
189
190#[derive(Debug, Clone)]
191pub struct Vector {
192    data: Vec<f64>,
193}
194
195impl Vector {
196    pub fn new(size: usize) -> Self {
197        Self {
198            data: vec![0.0; size],
199        }
200    }
201
202    pub fn from_slice(data: &[f64]) -> Self {
203        Self {
204            data: data.to_vec(),
205        }
206    }
207
208    pub fn zeros(size: usize) -> Self {
209        Self::new(size)
210    }
211
212    pub fn ones(size: usize) -> Self {
213        Self {
214            data: vec![1.0; size],
215        }
216    }
217
218    pub fn random(size: usize) -> Self {
219        let mut vector = Self::new(size);
220        for i in 0..size {
221            #[cfg(feature = "wasm")]
222            {
223                vector.data[i] = fastrand::f64();
224            }
225            #[cfg(not(feature = "wasm"))]
226            {
227                vector.data[i] = rand::random::<f64>();
228            }
229        }
230        vector
231    }
232
233    pub fn len(&self) -> usize {
234        self.data.len()
235    }
236
237    pub fn is_empty(&self) -> bool {
238        self.data.is_empty()
239    }
240
241    pub fn data(&self) -> &[f64] {
242        &self.data
243    }
244
245    pub fn data_mut(&mut self) -> &mut [f64] {
246        &mut self.data
247    }
248
249    pub fn get(&self, index: usize) -> f64 {
250        self.data[index]
251    }
252
253    pub fn set(&mut self, index: usize, value: f64) {
254        self.data[index] = value;
255    }
256
257    pub fn dot(&self, other: &Vector) -> f64 {
258        assert_eq!(self.len(), other.len(), "Vector lengths must match for dot product");
259
260        self.data.iter()
261            .zip(other.data.iter())
262            .map(|(a, b)| a * b)
263            .sum()
264    }
265
266    pub fn norm(&self) -> f64 {
267        self.dot(self).sqrt()
268    }
269
270    pub fn normalize(&mut self) {
271        let norm = self.norm();
272        if norm > 0.0 {
273            for x in &mut self.data {
274                *x /= norm;
275            }
276        }
277    }
278
279    pub fn add(&self, other: &Vector) -> Vector {
280        assert_eq!(self.len(), other.len(), "Vector lengths must match for addition");
281
282        let mut result = Vector::new(self.len());
283        for i in 0..self.len() {
284            result.data[i] = self.data[i] + other.data[i];
285        }
286        result
287    }
288
289    pub fn subtract(&self, other: &Vector) -> Vector {
290        assert_eq!(self.len(), other.len(), "Vector lengths must match for subtraction");
291
292        let mut result = Vector::new(self.len());
293        for i in 0..self.len() {
294            result.data[i] = self.data[i] - other.data[i];
295        }
296        result
297    }
298
299    pub fn scale(&self, scalar: f64) -> Vector {
300        let mut result = Vector::new(self.len());
301        for i in 0..self.len() {
302            result.data[i] = self.data[i] * scalar;
303        }
304        result
305    }
306
307    pub fn axpy(&mut self, alpha: f64, x: &Vector) {
308        assert_eq!(self.len(), x.len(), "Vector lengths must match for axpy");
309
310        for i in 0..self.len() {
311            self.data[i] += alpha * x.data[i];
312        }
313    }
314}
315
316impl fmt::Display for Vector {
317    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
318        write!(f, "[")?;
319        for (i, &value) in self.data.iter().enumerate() {
320            if i > 0 {
321                write!(f, ", ")?;
322            }
323            write!(f, "{:8.4}", value)?;
324        }
325        write!(f, "]")
326    }
327}
328
329// Matrix-Vector operations
330impl Matrix {
331    pub fn multiply_vector(&self, vector: &Vector) -> Result<Vector, String> {
332        if self.cols != vector.len() {
333            return Err("Matrix columns must match vector length".to_string());
334        }
335
336        let mut result = Vector::new(self.rows);
337        for i in 0..self.rows {
338            let mut sum = 0.0;
339            for j in 0..self.cols {
340                sum += self.get(i, j) * vector.get(j);
341            }
342            result.set(i, sum);
343        }
344
345        Ok(result)
346    }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built matrix reports its shape and owns `rows * cols` elements.
    #[test]
    fn test_matrix_creation() {
        let m = Matrix::new(3, 3);
        assert_eq!((m.rows(), m.cols()), (3, 3));
        assert_eq!(m.data().len(), 9);
    }

    /// The identity has ones on the diagonal and zeros elsewhere.
    #[test]
    fn test_matrix_identity() {
        let eye = Matrix::identity(3);
        for k in 0..3 {
            assert_eq!(eye.get(k, k), 1.0);
        }
        assert_eq!(eye.get(0, 1), 0.0);
    }

    /// Dot product and element-wise addition on small integer-valued vectors.
    #[test]
    fn test_vector_operations() {
        let a = Vector::from_slice(&[1.0, 2.0, 3.0]);
        let b = Vector::from_slice(&[4.0, 5.0, 6.0]);

        // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
        assert_eq!(a.dot(&b), 32.0);
        assert_eq!(a.add(&b).data(), &[5.0, 7.0, 9.0]);
    }

    /// A 2x2 matrix times a length-2 vector.
    #[test]
    fn test_matrix_vector_multiply() {
        let m = Matrix::from_slice(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let x = Vector::from_slice(&[1.0, 2.0]);

        let y = m
            .multiply_vector(&x)
            .expect("a 2x2 matrix accepts a length-2 vector");
        // [1*1 + 2*2, 3*1 + 4*2] = [5, 11]
        assert_eq!(y.data(), &[5.0, 11.0]);
    }
}