1use std::fmt;
2
/// A dense matrix of `f64` values stored in row-major order.
///
/// Entry `(row, col)` lives at `data[row * cols + col]`; the constructors keep
/// `data.len() == rows * cols`.
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix {
    /// Entries in row-major order.
    data: Vec<f64>,
    /// Number of rows.
    rows: usize,
    /// Number of columns.
    cols: usize,
}
9
10impl Matrix {
11 pub fn new(rows: usize, cols: usize) -> Self {
12 Self {
13 data: vec![0.0; rows * cols],
14 rows,
15 cols,
16 }
17 }
18
19 pub fn from_slice(data: &[f64], rows: usize, cols: usize) -> Self {
20 assert_eq!(
21 data.len(),
22 rows * cols,
23 "Data length must match matrix dimensions"
24 );
25 Self {
26 data: data.to_vec(),
27 rows,
28 cols,
29 }
30 }
31
32 pub fn identity(size: usize) -> Self {
33 let mut matrix = Self::new(size, size);
34 for i in 0..size {
35 matrix.data[i * size + i] = 1.0;
36 }
37 matrix
38 }
39
40 pub fn random(rows: usize, cols: usize) -> Self {
41 let mut matrix = Self::new(rows, cols);
42 for i in 0..matrix.data.len() {
43 #[cfg(feature = "wasm")]
44 {
45 matrix.data[i] = fastrand::f64();
46 }
47 #[cfg(not(feature = "wasm"))]
48 {
49 matrix.data[i] = rand::random::<f64>();
50 }
51 }
52 matrix
53 }
54
55 pub fn rows(&self) -> usize {
56 self.rows
57 }
58
59 pub fn cols(&self) -> usize {
60 self.cols
61 }
62
63 pub fn data(&self) -> &[f64] {
64 &self.data
65 }
66
67 pub fn data_mut(&mut self) -> &mut [f64] {
68 &mut self.data
69 }
70
71 pub fn get(&self, row: usize, col: usize) -> f64 {
72 assert!(row < self.rows && col < self.cols, "Index out of bounds");
73 self.data[row * self.cols + col]
74 }
75
76 pub fn set(&mut self, row: usize, col: usize, value: f64) {
77 assert!(row < self.rows && col < self.cols, "Index out of bounds");
78 self.data[row * self.cols + col] = value;
79 }
80
81 pub fn multiply(&self, other: &Matrix) -> Result<Matrix, String> {
82 if self.cols != other.rows {
83 return Err("Matrix dimensions incompatible for multiplication".to_string());
84 }
85
86 let mut result = Matrix::new(self.rows, other.cols);
87
88 for i in 0..self.rows {
89 for j in 0..other.cols {
90 let mut sum = 0.0;
91 for k in 0..self.cols {
92 sum += self.get(i, k) * other.get(k, j);
93 }
94 result.set(i, j, sum);
95 }
96 }
97
98 Ok(result)
99 }
100
101 pub fn transpose(&self) -> Matrix {
102 let mut result = Matrix::new(self.cols, self.rows);
103 for i in 0..self.rows {
104 for j in 0..self.cols {
105 result.set(j, i, self.get(i, j));
106 }
107 }
108 result
109 }
110
111 pub fn is_symmetric(&self) -> bool {
112 if self.rows != self.cols {
113 return false;
114 }
115
116 for i in 0..self.rows {
117 for j in 0..self.cols {
118 if (self.get(i, j) - self.get(j, i)).abs() > 1e-10 {
119 return false;
120 }
121 }
122 }
123 true
124 }
125
126 pub fn is_positive_definite(&self) -> bool {
127 if !self.is_symmetric() {
128 return false;
129 }
130
131 if self.rows <= 3 {
134 return self.check_sylvester_criterion();
135 }
136
137 true
139 }
140
141 fn check_sylvester_criterion(&self) -> bool {
142 for k in 1..=self.rows {
143 let det = self.leading_principal_minor(k);
144 if det <= 0.0 {
145 return false;
146 }
147 }
148 true
149 }
150
151 fn leading_principal_minor(&self, k: usize) -> f64 {
152 if k == 1 {
153 return self.get(0, 0);
154 }
155 if k == 2 {
156 return self.get(0, 0) * self.get(1, 1) - self.get(0, 1) * self.get(1, 0);
157 }
158 if k == 3 {
159 let a = self.get(0, 0);
160 let b = self.get(0, 1);
161 let c = self.get(0, 2);
162 let d = self.get(1, 0);
163 let e = self.get(1, 1);
164 let f = self.get(1, 2);
165 let g = self.get(2, 0);
166 let h = self.get(2, 1);
167 let i = self.get(2, 2);
168
169 return a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g);
170 }
171
172 1.0
174 }
175}
176
177impl fmt::Display for Matrix {
178 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
179 for i in 0..self.rows {
180 write!(f, "[")?;
181 for j in 0..self.cols {
182 if j > 0 {
183 write!(f, ", ")?;
184 }
185 write!(f, "{:8.4}", self.get(i, j))?;
186 }
187 writeln!(f, "]")?;
188 }
189 Ok(())
190 }
191}
192
/// A dense vector of `f64` values.
#[derive(Debug, Clone, PartialEq)]
pub struct Vector {
    /// The elements, in index order.
    data: Vec<f64>,
}
197
198impl Vector {
199 pub fn new(size: usize) -> Self {
200 Self {
201 data: vec![0.0; size],
202 }
203 }
204
205 pub fn from_slice(data: &[f64]) -> Self {
206 Self {
207 data: data.to_vec(),
208 }
209 }
210
211 pub fn zeros(size: usize) -> Self {
212 Self::new(size)
213 }
214
215 pub fn ones(size: usize) -> Self {
216 Self {
217 data: vec![1.0; size],
218 }
219 }
220
221 pub fn random(size: usize) -> Self {
222 let mut vector = Self::new(size);
223 for i in 0..size {
224 #[cfg(feature = "wasm")]
225 {
226 vector.data[i] = fastrand::f64();
227 }
228 #[cfg(not(feature = "wasm"))]
229 {
230 vector.data[i] = rand::random::<f64>();
231 }
232 }
233 vector
234 }
235
236 pub fn len(&self) -> usize {
237 self.data.len()
238 }
239
240 pub fn is_empty(&self) -> bool {
241 self.data.is_empty()
242 }
243
244 pub fn data(&self) -> &[f64] {
245 &self.data
246 }
247
248 pub fn data_mut(&mut self) -> &mut [f64] {
249 &mut self.data
250 }
251
252 pub fn get(&self, index: usize) -> f64 {
253 self.data[index]
254 }
255
256 pub fn set(&mut self, index: usize, value: f64) {
257 self.data[index] = value;
258 }
259
260 pub fn dot(&self, other: &Vector) -> f64 {
261 assert_eq!(
262 self.len(),
263 other.len(),
264 "Vector lengths must match for dot product"
265 );
266
267 self.data
268 .iter()
269 .zip(other.data.iter())
270 .map(|(a, b)| a * b)
271 .sum()
272 }
273
274 pub fn norm(&self) -> f64 {
275 self.dot(self).sqrt()
276 }
277
278 pub fn normalize(&mut self) {
279 let norm = self.norm();
280 if norm > 0.0 {
281 for x in &mut self.data {
282 *x /= norm;
283 }
284 }
285 }
286
287 pub fn add(&self, other: &Vector) -> Vector {
288 assert_eq!(
289 self.len(),
290 other.len(),
291 "Vector lengths must match for addition"
292 );
293
294 let mut result = Vector::new(self.len());
295 for i in 0..self.len() {
296 result.data[i] = self.data[i] + other.data[i];
297 }
298 result
299 }
300
301 pub fn subtract(&self, other: &Vector) -> Vector {
302 assert_eq!(
303 self.len(),
304 other.len(),
305 "Vector lengths must match for subtraction"
306 );
307
308 let mut result = Vector::new(self.len());
309 for i in 0..self.len() {
310 result.data[i] = self.data[i] - other.data[i];
311 }
312 result
313 }
314
315 pub fn scale(&self, scalar: f64) -> Vector {
316 let mut result = Vector::new(self.len());
317 for i in 0..self.len() {
318 result.data[i] = self.data[i] * scalar;
319 }
320 result
321 }
322
323 pub fn axpy(&mut self, alpha: f64, x: &Vector) {
324 assert_eq!(self.len(), x.len(), "Vector lengths must match for axpy");
325
326 for i in 0..self.len() {
327 self.data[i] += alpha * x.data[i];
328 }
329 }
330}
331
332impl fmt::Display for Vector {
333 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
334 write!(f, "[")?;
335 for (i, &value) in self.data.iter().enumerate() {
336 if i > 0 {
337 write!(f, ", ")?;
338 }
339 write!(f, "{:8.4}", value)?;
340 }
341 write!(f, "]")
342 }
343}
344
345impl Matrix {
347 pub fn multiply_vector(&self, vector: &Vector) -> Result<Vector, String> {
348 if self.cols != vector.len() {
349 return Err("Matrix columns must match vector length".to_string());
350 }
351
352 let mut result = Vector::new(self.rows);
353 for i in 0..self.rows {
354 let mut sum = 0.0;
355 for j in 0..self.cols {
356 sum += self.get(i, j) * vector.get(j);
357 }
358 result.set(i, sum);
359 }
360
361 Ok(result)
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_matrix_creation() {
        let matrix = Matrix::new(3, 3);
        assert_eq!(matrix.rows(), 3);
        assert_eq!(matrix.cols(), 3);
        assert_eq!(matrix.data().len(), 9);
    }

    #[test]
    fn test_matrix_identity() {
        let identity = Matrix::identity(3);
        assert_eq!(identity.get(0, 0), 1.0);
        assert_eq!(identity.get(1, 1), 1.0);
        assert_eq!(identity.get(2, 2), 1.0);
        assert_eq!(identity.get(0, 1), 0.0);
    }

    #[test]
    fn test_vector_operations() {
        let v1 = Vector::from_slice(&[1.0, 2.0, 3.0]);
        let v2 = Vector::from_slice(&[4.0, 5.0, 6.0]);

        let dot_product = v1.dot(&v2);
        assert_eq!(dot_product, 32.0);

        let sum = v1.add(&v2);
        assert_eq!(sum.data(), &[5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_matrix_vector_multiply() {
        let matrix = Matrix::from_slice(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let vector = Vector::from_slice(&[1.0, 2.0]);

        let result = matrix.multiply_vector(&vector).unwrap();
        assert_eq!(result.data(), &[5.0, 11.0]);
    }

    #[test]
    fn test_matrix_multiply_and_transpose() {
        let a = Matrix::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let at = a.transpose();
        assert_eq!((at.rows(), at.cols()), (3, 2));
        assert_eq!(at.data(), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);

        // A * Aᵀ holds the pairwise dot products of A's rows.
        let product = a.multiply(&at).unwrap();
        assert_eq!(product.data(), &[14.0, 32.0, 32.0, 77.0]);
        assert!(product.is_symmetric());
    }

    #[test]
    fn test_multiply_dimension_mismatch() {
        let a = Matrix::new(2, 3);
        assert!(a.multiply(&a).is_err());
        assert!(a.multiply_vector(&Vector::new(2)).is_err());
    }

    #[test]
    fn test_small_positive_definite() {
        let pd = Matrix::from_slice(&[2.0, 1.0, 1.0, 2.0], 2, 2);
        assert!(pd.is_positive_definite());

        let indefinite = Matrix::from_slice(&[1.0, 2.0, 2.0, 1.0], 2, 2);
        assert!(!indefinite.is_positive_definite());

        let non_symmetric = Matrix::from_slice(&[1.0, 2.0, 0.0, 1.0], 2, 2);
        assert!(!non_symmetric.is_positive_definite());

        let m3 = Matrix::from_slice(&[4.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0], 3, 3);
        assert!(m3.is_positive_definite());

        assert!(Matrix::identity(4).is_positive_definite());
    }

    #[test]
    fn test_display() {
        let matrix = Matrix::from_slice(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        assert_eq!(
            matrix.to_string(),
            "[  1.0000,   2.0000]\n[  3.0000,   4.0000]\n"
        );

        let vector = Vector::from_slice(&[1.0, 2.0]);
        assert_eq!(vector.to_string(), "[  1.0000,   2.0000]");
    }
}