sublinear_solver/
math_wasm.rs

1use std::fmt;
2
3#[derive(Debug, Clone)]
4pub struct Matrix {
5    data: Vec<f64>,
6    rows: usize,
7    cols: usize,
8}
9
10impl Matrix {
11    pub fn new(rows: usize, cols: usize) -> Self {
12        Self {
13            data: vec![0.0; rows * cols],
14            rows,
15            cols,
16        }
17    }
18
19    pub fn from_slice(data: &[f64], rows: usize, cols: usize) -> Self {
20        assert_eq!(data.len(), rows * cols, "Data length must match matrix dimensions");
21        Self {
22            data: data.to_vec(),
23            rows,
24            cols,
25        }
26    }
27
28    pub fn identity(size: usize) -> Self {
29        let mut matrix = Self::new(size, size);
30        for i in 0..size {
31            matrix.data[i * size + i] = 1.0;
32        }
33        matrix
34    }
35
36    pub fn random(rows: usize, cols: usize) -> Self {
37        let mut matrix = Self::new(rows, cols);
38        for i in 0..matrix.data.len() {
39            #[cfg(feature = "wasm")]
40            {
41                matrix.data[i] = fastrand::f64();
42            }
43            #[cfg(not(feature = "wasm"))]
44            {
45                matrix.data[i] = rand::random::<f64>();
46            }
47        }
48        matrix
49    }
50
51    pub fn rows(&self) -> usize {
52        self.rows
53    }
54
55    pub fn cols(&self) -> usize {
56        self.cols
57    }
58
59    pub fn data(&self) -> &[f64] {
60        &self.data
61    }
62
63    pub fn data_mut(&mut self) -> &mut [f64] {
64        &mut self.data
65    }
66
67    pub fn get(&self, row: usize, col: usize) -> f64 {
68        assert!(row < self.rows && col < self.cols, "Index out of bounds");
69        self.data[row * self.cols + col]
70    }
71
72    pub fn set(&mut self, row: usize, col: usize, value: f64) {
73        assert!(row < self.rows && col < self.cols, "Index out of bounds");
74        self.data[row * self.cols + col] = value;
75    }
76
77    pub fn multiply(&self, other: &Matrix) -> Result<Matrix, String> {
78        if self.cols != other.rows {
79            return Err("Matrix dimensions incompatible for multiplication".to_string());
80        }
81
82        let mut result = Matrix::new(self.rows, other.cols);
83
84        for i in 0..self.rows {
85            for j in 0..other.cols {
86                let mut sum = 0.0;
87                for k in 0..self.cols {
88                    sum += self.get(i, k) * other.get(k, j);
89                }
90                result.set(i, j, sum);
91            }
92        }
93
94        Ok(result)
95    }
96
97    pub fn transpose(&self) -> Matrix {
98        let mut result = Matrix::new(self.cols, self.rows);
99        for i in 0..self.rows {
100            for j in 0..self.cols {
101                result.set(j, i, self.get(i, j));
102            }
103        }
104        result
105    }
106
107    pub fn is_symmetric(&self) -> bool {
108        if self.rows != self.cols {
109            return false;
110        }
111
112        for i in 0..self.rows {
113            for j in 0..self.cols {
114                if (self.get(i, j) - self.get(j, i)).abs() > 1e-10 {
115                    return false;
116                }
117            }
118        }
119        true
120    }
121
122    pub fn is_positive_definite(&self) -> bool {
123        if !self.is_symmetric() {
124            return false;
125        }
126
127        // Simple check using Sylvester's criterion for small matrices
128        // For larger matrices, this should use Cholesky decomposition
129        if self.rows <= 3 {
130            return self.check_sylvester_criterion();
131        }
132
133        // For larger matrices, approximate check
134        true
135    }
136
137    fn check_sylvester_criterion(&self) -> bool {
138        for k in 1..=self.rows {
139            let det = self.leading_principal_minor(k);
140            if det <= 0.0 {
141                return false;
142            }
143        }
144        true
145    }
146
147    fn leading_principal_minor(&self, k: usize) -> f64 {
148        if k == 1 {
149            return self.get(0, 0);
150        }
151        if k == 2 {
152            return self.get(0, 0) * self.get(1, 1) - self.get(0, 1) * self.get(1, 0);
153        }
154        if k == 3 {
155            let a = self.get(0, 0);
156            let b = self.get(0, 1);
157            let c = self.get(0, 2);
158            let d = self.get(1, 0);
159            let e = self.get(1, 1);
160            let f = self.get(1, 2);
161            let g = self.get(2, 0);
162            let h = self.get(2, 1);
163            let i = self.get(2, 2);
164
165            return a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g);
166        }
167
168        // For larger matrices, use a simplified approximation
169        1.0
170    }
171}
172
173impl fmt::Display for Matrix {
174    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
175        for i in 0..self.rows {
176            write!(f, "[")?;
177            for j in 0..self.cols {
178                if j > 0 {
179                    write!(f, ", ")?;
180                }
181                write!(f, "{:8.4}", self.get(i, j))?;
182            }
183            writeln!(f, "]")?;
184        }
185        Ok(())
186    }
187}
188
189#[derive(Debug, Clone)]
190pub struct Vector {
191    data: Vec<f64>,
192}
193
194impl Vector {
195    pub fn new(size: usize) -> Self {
196        Self {
197            data: vec![0.0; size],
198        }
199    }
200
201    pub fn from_slice(data: &[f64]) -> Self {
202        Self {
203            data: data.to_vec(),
204        }
205    }
206
207    pub fn zeros(size: usize) -> Self {
208        Self::new(size)
209    }
210
211    pub fn ones(size: usize) -> Self {
212        Self {
213            data: vec![1.0; size],
214        }
215    }
216
217    pub fn random(size: usize) -> Self {
218        let mut vector = Self::new(size);
219        for i in 0..size {
220            #[cfg(feature = "wasm")]
221            {
222                vector.data[i] = fastrand::f64();
223            }
224            #[cfg(not(feature = "wasm"))]
225            {
226                vector.data[i] = rand::random::<f64>();
227            }
228        }
229        vector
230    }
231
232    pub fn len(&self) -> usize {
233        self.data.len()
234    }
235
236    pub fn is_empty(&self) -> bool {
237        self.data.is_empty()
238    }
239
240    pub fn data(&self) -> &[f64] {
241        &self.data
242    }
243
244    pub fn data_mut(&mut self) -> &mut [f64] {
245        &mut self.data
246    }
247
248    pub fn get(&self, index: usize) -> f64 {
249        self.data[index]
250    }
251
252    pub fn set(&mut self, index: usize, value: f64) {
253        self.data[index] = value;
254    }
255
256    pub fn dot(&self, other: &Vector) -> f64 {
257        assert_eq!(self.len(), other.len(), "Vector lengths must match for dot product");
258
259        self.data.iter()
260            .zip(other.data.iter())
261            .map(|(a, b)| a * b)
262            .sum()
263    }
264
265    pub fn norm(&self) -> f64 {
266        self.dot(self).sqrt()
267    }
268
269    pub fn normalize(&mut self) {
270        let norm = self.norm();
271        if norm > 0.0 {
272            for x in &mut self.data {
273                *x /= norm;
274            }
275        }
276    }
277
278    pub fn add(&self, other: &Vector) -> Vector {
279        assert_eq!(self.len(), other.len(), "Vector lengths must match for addition");
280
281        let mut result = Vector::new(self.len());
282        for i in 0..self.len() {
283            result.data[i] = self.data[i] + other.data[i];
284        }
285        result
286    }
287
288    pub fn subtract(&self, other: &Vector) -> Vector {
289        assert_eq!(self.len(), other.len(), "Vector lengths must match for subtraction");
290
291        let mut result = Vector::new(self.len());
292        for i in 0..self.len() {
293            result.data[i] = self.data[i] - other.data[i];
294        }
295        result
296    }
297
298    pub fn scale(&self, scalar: f64) -> Vector {
299        let mut result = Vector::new(self.len());
300        for i in 0..self.len() {
301            result.data[i] = self.data[i] * scalar;
302        }
303        result
304    }
305
306    pub fn axpy(&mut self, alpha: f64, x: &Vector) {
307        assert_eq!(self.len(), x.len(), "Vector lengths must match for axpy");
308
309        for i in 0..self.len() {
310            self.data[i] += alpha * x.data[i];
311        }
312    }
313}
314
315impl fmt::Display for Vector {
316    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
317        write!(f, "[")?;
318        for (i, &value) in self.data.iter().enumerate() {
319            if i > 0 {
320                write!(f, ", ")?;
321            }
322            write!(f, "{:8.4}", value)?;
323        }
324        write!(f, "]")
325    }
326}
327
328// Matrix-Vector operations
329impl Matrix {
330    pub fn multiply_vector(&self, vector: &Vector) -> Result<Vector, String> {
331        if self.cols != vector.len() {
332            return Err("Matrix columns must match vector length".to_string());
333        }
334
335        let mut result = Vector::new(self.rows);
336        for i in 0..self.rows {
337            let mut sum = 0.0;
338            for j in 0..self.cols {
339                sum += self.get(i, j) * vector.get(j);
340            }
341            result.set(i, sum);
342        }
343
344        Ok(result)
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// A new matrix reports its dimensions and exposes `rows * cols` elements.
    #[test]
    fn test_matrix_creation() {
        let zeros = Matrix::new(3, 3);
        assert_eq!((zeros.rows(), zeros.cols()), (3, 3));
        assert_eq!(zeros.data().len(), 9);
    }

    /// The identity matrix has ones on the diagonal and zero off it.
    #[test]
    fn test_matrix_identity() {
        let eye = Matrix::identity(3);
        for diagonal in 0..3 {
            assert_eq!(eye.get(diagonal, diagonal), 1.0);
        }
        assert_eq!(eye.get(0, 1), 0.0);
    }

    /// Dot product and component-wise addition on small integer-valued vectors.
    #[test]
    fn test_vector_operations() {
        let lhs = Vector::from_slice(&[1.0, 2.0, 3.0]);
        let rhs = Vector::from_slice(&[4.0, 5.0, 6.0]);

        // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
        assert_eq!(lhs.dot(&rhs), 32.0);

        let total = lhs.add(&rhs);
        assert_eq!(total.data(), &[5.0, 7.0, 9.0]);
    }

    /// Matrix-vector product of a 2x2 matrix with a 2-vector.
    #[test]
    fn test_matrix_vector_multiply() {
        let coefficients = Matrix::from_slice(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let input = Vector::from_slice(&[1.0, 2.0]);

        let product = coefficients.multiply_vector(&input).unwrap();
        // [1*1 + 2*2, 3*1 + 4*2] = [5, 11]
        assert_eq!(product.data(), &[5.0, 11.0]);
    }
}