Skip to main content

sublinear_solver/
math_wasm.rs

1use std::fmt;
2
3#[derive(Debug, Clone)]
4pub struct Matrix {
5    data: Vec<f64>,
6    rows: usize,
7    cols: usize,
8}
9
10impl Matrix {
11    pub fn new(rows: usize, cols: usize) -> Self {
12        Self {
13            data: vec![0.0; rows * cols],
14            rows,
15            cols,
16        }
17    }
18
19    pub fn from_slice(data: &[f64], rows: usize, cols: usize) -> Self {
20        assert_eq!(
21            data.len(),
22            rows * cols,
23            "Data length must match matrix dimensions"
24        );
25        Self {
26            data: data.to_vec(),
27            rows,
28            cols,
29        }
30    }
31
32    pub fn identity(size: usize) -> Self {
33        let mut matrix = Self::new(size, size);
34        for i in 0..size {
35            matrix.data[i * size + i] = 1.0;
36        }
37        matrix
38    }
39
40    pub fn random(rows: usize, cols: usize) -> Self {
41        let mut matrix = Self::new(rows, cols);
42        for i in 0..matrix.data.len() {
43            #[cfg(feature = "wasm")]
44            {
45                matrix.data[i] = fastrand::f64();
46            }
47            #[cfg(not(feature = "wasm"))]
48            {
49                matrix.data[i] = rand::random::<f64>();
50            }
51        }
52        matrix
53    }
54
55    pub fn rows(&self) -> usize {
56        self.rows
57    }
58
59    pub fn cols(&self) -> usize {
60        self.cols
61    }
62
63    pub fn data(&self) -> &[f64] {
64        &self.data
65    }
66
67    pub fn data_mut(&mut self) -> &mut [f64] {
68        &mut self.data
69    }
70
71    pub fn get(&self, row: usize, col: usize) -> f64 {
72        assert!(row < self.rows && col < self.cols, "Index out of bounds");
73        self.data[row * self.cols + col]
74    }
75
76    pub fn set(&mut self, row: usize, col: usize, value: f64) {
77        assert!(row < self.rows && col < self.cols, "Index out of bounds");
78        self.data[row * self.cols + col] = value;
79    }
80
81    pub fn multiply(&self, other: &Matrix) -> Result<Matrix, String> {
82        if self.cols != other.rows {
83            return Err("Matrix dimensions incompatible for multiplication".to_string());
84        }
85
86        let mut result = Matrix::new(self.rows, other.cols);
87
88        for i in 0..self.rows {
89            for j in 0..other.cols {
90                let mut sum = 0.0;
91                for k in 0..self.cols {
92                    sum += self.get(i, k) * other.get(k, j);
93                }
94                result.set(i, j, sum);
95            }
96        }
97
98        Ok(result)
99    }
100
101    pub fn transpose(&self) -> Matrix {
102        let mut result = Matrix::new(self.cols, self.rows);
103        for i in 0..self.rows {
104            for j in 0..self.cols {
105                result.set(j, i, self.get(i, j));
106            }
107        }
108        result
109    }
110
111    pub fn is_symmetric(&self) -> bool {
112        if self.rows != self.cols {
113            return false;
114        }
115
116        for i in 0..self.rows {
117            for j in 0..self.cols {
118                if (self.get(i, j) - self.get(j, i)).abs() > 1e-10 {
119                    return false;
120                }
121            }
122        }
123        true
124    }
125
126    pub fn is_positive_definite(&self) -> bool {
127        if !self.is_symmetric() {
128            return false;
129        }
130
131        // Simple check using Sylvester's criterion for small matrices
132        // For larger matrices, this should use Cholesky decomposition
133        if self.rows <= 3 {
134            return self.check_sylvester_criterion();
135        }
136
137        // For larger matrices, approximate check
138        true
139    }
140
141    fn check_sylvester_criterion(&self) -> bool {
142        for k in 1..=self.rows {
143            let det = self.leading_principal_minor(k);
144            if det <= 0.0 {
145                return false;
146            }
147        }
148        true
149    }
150
151    fn leading_principal_minor(&self, k: usize) -> f64 {
152        if k == 1 {
153            return self.get(0, 0);
154        }
155        if k == 2 {
156            return self.get(0, 0) * self.get(1, 1) - self.get(0, 1) * self.get(1, 0);
157        }
158        if k == 3 {
159            let a = self.get(0, 0);
160            let b = self.get(0, 1);
161            let c = self.get(0, 2);
162            let d = self.get(1, 0);
163            let e = self.get(1, 1);
164            let f = self.get(1, 2);
165            let g = self.get(2, 0);
166            let h = self.get(2, 1);
167            let i = self.get(2, 2);
168
169            return a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g);
170        }
171
172        // For larger matrices, use a simplified approximation
173        1.0
174    }
175}
176
177impl fmt::Display for Matrix {
178    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
179        for i in 0..self.rows {
180            write!(f, "[")?;
181            for j in 0..self.cols {
182                if j > 0 {
183                    write!(f, ", ")?;
184                }
185                write!(f, "{:8.4}", self.get(i, j))?;
186            }
187            writeln!(f, "]")?;
188        }
189        Ok(())
190    }
191}
192
193#[derive(Debug, Clone)]
194pub struct Vector {
195    data: Vec<f64>,
196}
197
198impl Vector {
199    pub fn new(size: usize) -> Self {
200        Self {
201            data: vec![0.0; size],
202        }
203    }
204
205    pub fn from_slice(data: &[f64]) -> Self {
206        Self {
207            data: data.to_vec(),
208        }
209    }
210
211    pub fn zeros(size: usize) -> Self {
212        Self::new(size)
213    }
214
215    pub fn ones(size: usize) -> Self {
216        Self {
217            data: vec![1.0; size],
218        }
219    }
220
221    pub fn random(size: usize) -> Self {
222        let mut vector = Self::new(size);
223        for i in 0..size {
224            #[cfg(feature = "wasm")]
225            {
226                vector.data[i] = fastrand::f64();
227            }
228            #[cfg(not(feature = "wasm"))]
229            {
230                vector.data[i] = rand::random::<f64>();
231            }
232        }
233        vector
234    }
235
236    pub fn len(&self) -> usize {
237        self.data.len()
238    }
239
240    pub fn is_empty(&self) -> bool {
241        self.data.is_empty()
242    }
243
244    pub fn data(&self) -> &[f64] {
245        &self.data
246    }
247
248    pub fn data_mut(&mut self) -> &mut [f64] {
249        &mut self.data
250    }
251
252    pub fn get(&self, index: usize) -> f64 {
253        self.data[index]
254    }
255
256    pub fn set(&mut self, index: usize, value: f64) {
257        self.data[index] = value;
258    }
259
260    pub fn dot(&self, other: &Vector) -> f64 {
261        assert_eq!(
262            self.len(),
263            other.len(),
264            "Vector lengths must match for dot product"
265        );
266
267        self.data
268            .iter()
269            .zip(other.data.iter())
270            .map(|(a, b)| a * b)
271            .sum()
272    }
273
274    pub fn norm(&self) -> f64 {
275        self.dot(self).sqrt()
276    }
277
278    pub fn normalize(&mut self) {
279        let norm = self.norm();
280        if norm > 0.0 {
281            for x in &mut self.data {
282                *x /= norm;
283            }
284        }
285    }
286
287    pub fn add(&self, other: &Vector) -> Vector {
288        assert_eq!(
289            self.len(),
290            other.len(),
291            "Vector lengths must match for addition"
292        );
293
294        let mut result = Vector::new(self.len());
295        for i in 0..self.len() {
296            result.data[i] = self.data[i] + other.data[i];
297        }
298        result
299    }
300
301    pub fn subtract(&self, other: &Vector) -> Vector {
302        assert_eq!(
303            self.len(),
304            other.len(),
305            "Vector lengths must match for subtraction"
306        );
307
308        let mut result = Vector::new(self.len());
309        for i in 0..self.len() {
310            result.data[i] = self.data[i] - other.data[i];
311        }
312        result
313    }
314
315    pub fn scale(&self, scalar: f64) -> Vector {
316        let mut result = Vector::new(self.len());
317        for i in 0..self.len() {
318            result.data[i] = self.data[i] * scalar;
319        }
320        result
321    }
322
323    pub fn axpy(&mut self, alpha: f64, x: &Vector) {
324        assert_eq!(self.len(), x.len(), "Vector lengths must match for axpy");
325
326        for i in 0..self.len() {
327            self.data[i] += alpha * x.data[i];
328        }
329    }
330}
331
332impl fmt::Display for Vector {
333    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
334        write!(f, "[")?;
335        for (i, &value) in self.data.iter().enumerate() {
336            if i > 0 {
337                write!(f, ", ")?;
338            }
339            write!(f, "{:8.4}", value)?;
340        }
341        write!(f, "]")
342    }
343}
344
345// Matrix-Vector operations
346impl Matrix {
347    pub fn multiply_vector(&self, vector: &Vector) -> Result<Vector, String> {
348        if self.cols != vector.len() {
349            return Err("Matrix columns must match vector length".to_string());
350        }
351
352        let mut result = Vector::new(self.rows);
353        for i in 0..self.rows {
354            let mut sum = 0.0;
355            for j in 0..self.cols {
356                sum += self.get(i, j) * vector.get(j);
357            }
358            result.set(i, sum);
359        }
360
361        Ok(result)
362    }
363}
364
#[cfg(test)]
mod tests {
    //! Unit tests for `Matrix` and `Vector`.

    use super::*;

    #[test]
    fn test_matrix_creation() {
        let matrix = Matrix::new(3, 3);
        assert_eq!(matrix.rows(), 3);
        assert_eq!(matrix.cols(), 3);
        assert_eq!(matrix.data().len(), 9);
    }

    #[test]
    fn test_matrix_identity() {
        let identity = Matrix::identity(3);
        assert_eq!(identity.get(0, 0), 1.0);
        assert_eq!(identity.get(1, 1), 1.0);
        assert_eq!(identity.get(2, 2), 1.0);
        assert_eq!(identity.get(0, 1), 0.0);
    }

    #[test]
    fn test_vector_operations() {
        let v1 = Vector::from_slice(&[1.0, 2.0, 3.0]);
        let v2 = Vector::from_slice(&[4.0, 5.0, 6.0]);

        let dot_product = v1.dot(&v2);
        assert_eq!(dot_product, 32.0); // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32

        let sum = v1.add(&v2);
        assert_eq!(sum.data(), &[5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_matrix_vector_multiply() {
        let matrix = Matrix::from_slice(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let vector = Vector::from_slice(&[1.0, 2.0]);

        let result = matrix.multiply_vector(&vector).unwrap();
        assert_eq!(result.data(), &[5.0, 11.0]); // [1*1+2*2, 3*1+4*2] = [5, 11]
    }

    #[test]
    fn test_matrix_vector_dimension_mismatch() {
        let matrix = Matrix::new(2, 3);
        assert!(matrix.multiply_vector(&Vector::new(2)).is_err());
    }

    #[test]
    fn test_matrix_multiply_and_transpose() {
        let a = Matrix::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let b = Matrix::from_slice(&[7.0, 8.0, 9.0, 10.0, 11.0, 12.0], 3, 2);

        let c = a.multiply(&b).unwrap();
        assert_eq!((c.rows(), c.cols()), (2, 2));
        assert_eq!(c.data(), &[58.0, 64.0, 139.0, 154.0]);
        // 2x3 * 2x3: inner dimensions disagree.
        assert!(a.multiply(&a).is_err());

        let t = a.transpose();
        assert_eq!((t.rows(), t.cols()), (3, 2));
        assert_eq!(t.data(), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_matrix_symmetry_and_definiteness() {
        let spd = Matrix::from_slice(&[2.0, 1.0, 1.0, 2.0], 2, 2);
        assert!(spd.is_symmetric());
        assert!(spd.is_positive_definite());

        let indefinite = Matrix::from_slice(&[1.0, 2.0, 2.0, 1.0], 2, 2);
        assert!(indefinite.is_symmetric());
        assert!(!indefinite.is_positive_definite());

        let asymmetric = Matrix::from_slice(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        assert!(!asymmetric.is_symmetric());
        assert!(!asymmetric.is_positive_definite());

        assert!(Matrix::identity(4).is_positive_definite());
    }

    #[test]
    fn test_vector_norm_and_arithmetic() {
        let mut v = Vector::from_slice(&[3.0, 4.0]);
        assert_eq!(v.norm(), 5.0);
        v.normalize();
        assert_eq!(v.data(), &[0.6, 0.8]);

        let mut zero = Vector::zeros(2);
        zero.normalize();
        assert_eq!(zero.data(), &[0.0, 0.0]);

        let a = Vector::from_slice(&[1.0, 2.0, 3.0]);
        let b = Vector::from_slice(&[4.0, 5.0, 6.0]);
        assert_eq!(b.subtract(&a).data(), &[3.0, 3.0, 3.0]);
        assert_eq!(a.scale(2.0).data(), &[2.0, 4.0, 6.0]);

        let mut y = Vector::ones(3);
        y.axpy(2.0, &a);
        assert_eq!(y.data(), &[3.0, 5.0, 7.0]);
    }

    #[test]
    fn test_display_formatting() {
        assert_eq!(
            format!("{}", Matrix::identity(2)),
            "[  1.0000,   0.0000]\n[  0.0000,   1.0000]\n"
        );
        assert_eq!(
            format!("{}", Vector::from_slice(&[1.0, -2.5])),
            "[  1.0000,  -2.5000]"
        );
        assert_eq!(format!("{}", Vector::new(0)), "[]");
    }
}