1use std::fmt;
2use rand::Rng;
3
4#[derive(Debug, Clone)]
5pub struct Matrix {
6 data: Vec<f64>,
7 rows: usize,
8 cols: usize,
9}
10
11impl Matrix {
12 pub fn new(rows: usize, cols: usize) -> Self {
13 Self {
14 data: vec![0.0; rows * cols],
15 rows,
16 cols,
17 }
18 }
19
20 pub fn from_slice(data: &[f64], rows: usize, cols: usize) -> Self {
21 assert_eq!(data.len(), rows * cols, "Data length must match matrix dimensions");
22 Self {
23 data: data.to_vec(),
24 rows,
25 cols,
26 }
27 }
28
29 pub fn identity(size: usize) -> Self {
30 let mut matrix = Self::new(size, size);
31 for i in 0..size {
32 matrix.data[i * size + i] = 1.0;
33 }
34 matrix
35 }
36
37 pub fn random(rows: usize, cols: usize) -> Self {
38 let mut matrix = Self::new(rows, cols);
39 for i in 0..matrix.data.len() {
40 #[cfg(feature = "wasm")]
41 {
42 matrix.data[i] = fastrand::f64();
43 }
44 #[cfg(not(feature = "wasm"))]
45 {
46 matrix.data[i] = rand::random::<f64>();
47 }
48 }
49 matrix
50 }
51
52 pub fn rows(&self) -> usize {
53 self.rows
54 }
55
56 pub fn cols(&self) -> usize {
57 self.cols
58 }
59
60 pub fn data(&self) -> &[f64] {
61 &self.data
62 }
63
64 pub fn data_mut(&mut self) -> &mut [f64] {
65 &mut self.data
66 }
67
68 pub fn get(&self, row: usize, col: usize) -> f64 {
69 assert!(row < self.rows && col < self.cols, "Index out of bounds");
70 self.data[row * self.cols + col]
71 }
72
73 pub fn set(&mut self, row: usize, col: usize, value: f64) {
74 assert!(row < self.rows && col < self.cols, "Index out of bounds");
75 self.data[row * self.cols + col] = value;
76 }
77
78 pub fn multiply(&self, other: &Matrix) -> Result<Matrix, String> {
79 if self.cols != other.rows {
80 return Err("Matrix dimensions incompatible for multiplication".to_string());
81 }
82
83 let mut result = Matrix::new(self.rows, other.cols);
84
85 for i in 0..self.rows {
86 for j in 0..other.cols {
87 let mut sum = 0.0;
88 for k in 0..self.cols {
89 sum += self.get(i, k) * other.get(k, j);
90 }
91 result.set(i, j, sum);
92 }
93 }
94
95 Ok(result)
96 }
97
98 pub fn transpose(&self) -> Matrix {
99 let mut result = Matrix::new(self.cols, self.rows);
100 for i in 0..self.rows {
101 for j in 0..self.cols {
102 result.set(j, i, self.get(i, j));
103 }
104 }
105 result
106 }
107
108 pub fn is_symmetric(&self) -> bool {
109 if self.rows != self.cols {
110 return false;
111 }
112
113 for i in 0..self.rows {
114 for j in 0..self.cols {
115 if (self.get(i, j) - self.get(j, i)).abs() > 1e-10 {
116 return false;
117 }
118 }
119 }
120 true
121 }
122
123 pub fn is_positive_definite(&self) -> bool {
124 if !self.is_symmetric() {
125 return false;
126 }
127
128 if self.rows <= 3 {
131 return self.check_sylvester_criterion();
132 }
133
134 true
136 }
137
138 fn check_sylvester_criterion(&self) -> bool {
139 for k in 1..=self.rows {
140 let det = self.leading_principal_minor(k);
141 if det <= 0.0 {
142 return false;
143 }
144 }
145 true
146 }
147
148 fn leading_principal_minor(&self, k: usize) -> f64 {
149 if k == 1 {
150 return self.get(0, 0);
151 }
152 if k == 2 {
153 return self.get(0, 0) * self.get(1, 1) - self.get(0, 1) * self.get(1, 0);
154 }
155 if k == 3 {
156 let a = self.get(0, 0);
157 let b = self.get(0, 1);
158 let c = self.get(0, 2);
159 let d = self.get(1, 0);
160 let e = self.get(1, 1);
161 let f = self.get(1, 2);
162 let g = self.get(2, 0);
163 let h = self.get(2, 1);
164 let i = self.get(2, 2);
165
166 return a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g);
167 }
168
169 1.0
171 }
172}
173
174impl fmt::Display for Matrix {
175 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
176 for i in 0..self.rows {
177 write!(f, "[")?;
178 for j in 0..self.cols {
179 if j > 0 {
180 write!(f, ", ")?;
181 }
182 write!(f, "{:8.4}", self.get(i, j))?;
183 }
184 writeln!(f, "]")?;
185 }
186 Ok(())
187 }
188}
189
190#[derive(Debug, Clone)]
191pub struct Vector {
192 data: Vec<f64>,
193}
194
195impl Vector {
196 pub fn new(size: usize) -> Self {
197 Self {
198 data: vec![0.0; size],
199 }
200 }
201
202 pub fn from_slice(data: &[f64]) -> Self {
203 Self {
204 data: data.to_vec(),
205 }
206 }
207
208 pub fn zeros(size: usize) -> Self {
209 Self::new(size)
210 }
211
212 pub fn ones(size: usize) -> Self {
213 Self {
214 data: vec![1.0; size],
215 }
216 }
217
218 pub fn random(size: usize) -> Self {
219 let mut vector = Self::new(size);
220 for i in 0..size {
221 #[cfg(feature = "wasm")]
222 {
223 vector.data[i] = fastrand::f64();
224 }
225 #[cfg(not(feature = "wasm"))]
226 {
227 vector.data[i] = rand::random::<f64>();
228 }
229 }
230 vector
231 }
232
233 pub fn len(&self) -> usize {
234 self.data.len()
235 }
236
237 pub fn is_empty(&self) -> bool {
238 self.data.is_empty()
239 }
240
241 pub fn data(&self) -> &[f64] {
242 &self.data
243 }
244
245 pub fn data_mut(&mut self) -> &mut [f64] {
246 &mut self.data
247 }
248
249 pub fn get(&self, index: usize) -> f64 {
250 self.data[index]
251 }
252
253 pub fn set(&mut self, index: usize, value: f64) {
254 self.data[index] = value;
255 }
256
257 pub fn dot(&self, other: &Vector) -> f64 {
258 assert_eq!(self.len(), other.len(), "Vector lengths must match for dot product");
259
260 self.data.iter()
261 .zip(other.data.iter())
262 .map(|(a, b)| a * b)
263 .sum()
264 }
265
266 pub fn norm(&self) -> f64 {
267 self.dot(self).sqrt()
268 }
269
270 pub fn normalize(&mut self) {
271 let norm = self.norm();
272 if norm > 0.0 {
273 for x in &mut self.data {
274 *x /= norm;
275 }
276 }
277 }
278
279 pub fn add(&self, other: &Vector) -> Vector {
280 assert_eq!(self.len(), other.len(), "Vector lengths must match for addition");
281
282 let mut result = Vector::new(self.len());
283 for i in 0..self.len() {
284 result.data[i] = self.data[i] + other.data[i];
285 }
286 result
287 }
288
289 pub fn subtract(&self, other: &Vector) -> Vector {
290 assert_eq!(self.len(), other.len(), "Vector lengths must match for subtraction");
291
292 let mut result = Vector::new(self.len());
293 for i in 0..self.len() {
294 result.data[i] = self.data[i] - other.data[i];
295 }
296 result
297 }
298
299 pub fn scale(&self, scalar: f64) -> Vector {
300 let mut result = Vector::new(self.len());
301 for i in 0..self.len() {
302 result.data[i] = self.data[i] * scalar;
303 }
304 result
305 }
306
307 pub fn axpy(&mut self, alpha: f64, x: &Vector) {
308 assert_eq!(self.len(), x.len(), "Vector lengths must match for axpy");
309
310 for i in 0..self.len() {
311 self.data[i] += alpha * x.data[i];
312 }
313 }
314}
315
316impl fmt::Display for Vector {
317 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
318 write!(f, "[")?;
319 for (i, &value) in self.data.iter().enumerate() {
320 if i > 0 {
321 write!(f, ", ")?;
322 }
323 write!(f, "{:8.4}", value)?;
324 }
325 write!(f, "]")
326 }
327}
328
329impl Matrix {
331 pub fn multiply_vector(&self, vector: &Vector) -> Result<Vector, String> {
332 if self.cols != vector.len() {
333 return Err("Matrix columns must match vector length".to_string());
334 }
335
336 let mut result = Vector::new(self.rows);
337 for i in 0..self.rows {
338 let mut sum = 0.0;
339 for j in 0..self.cols {
340 sum += self.get(i, j) * vector.get(j);
341 }
342 result.set(i, sum);
343 }
344
345 Ok(result)
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// A new matrix reports the requested dimensions and owns one element
    /// per cell.
    #[test]
    fn test_matrix_creation() {
        let (rows, cols) = (3, 3);
        let matrix = Matrix::new(rows, cols);

        assert_eq!(matrix.rows(), rows);
        assert_eq!(matrix.cols(), cols);
        assert_eq!(matrix.data().len(), 9);
    }

    /// The identity matrix has ones on the diagonal and zero off it.
    #[test]
    fn test_matrix_identity() {
        let identity = Matrix::identity(3);

        for diagonal in 0..3 {
            assert_eq!(identity.get(diagonal, diagonal), 1.0);
        }
        assert_eq!(identity.get(0, 1), 0.0);
    }

    /// Dot product and component-wise addition on two 3-component vectors.
    #[test]
    fn test_vector_operations() {
        let v1 = Vector::from_slice(&[1.0, 2.0, 3.0]);
        let v2 = Vector::from_slice(&[4.0, 5.0, 6.0]);

        // 1*4 + 2*5 + 3*6
        let dot_product = v1.dot(&v2);
        assert_eq!(dot_product, 32.0);

        let sum = v1.add(&v2);
        assert_eq!(sum.data(), &[5.0, 7.0, 9.0]);
    }

    /// A 2x2 matrix times a 2-component vector.
    #[test]
    fn test_matrix_vector_multiply() {
        let matrix = Matrix::from_slice(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let vector = Vector::from_slice(&[1.0, 2.0]);

        // Rows give 1*1 + 2*2 = 5 and 3*1 + 4*2 = 11.
        let result = matrix.multiply_vector(&vector).unwrap();
        assert_eq!(result.data(), &[5.0, 11.0]);
    }
}