1use std::fmt;
2
/// A dense matrix of `f64` values stored in row-major order.
///
/// Entry (`row`, `col`) lives at `data[row * cols + col]`. Equality (`PartialEq`) compares the
/// dimensions and every entry with `f64` `==`, so a matrix containing NaN never equals itself.
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix {
    /// Row-major storage; the constructors keep its length equal to `rows * cols`.
    data: Vec<f64>,
    /// Number of rows.
    rows: usize,
    /// Number of columns.
    cols: usize,
}
9
10impl Matrix {
11 pub fn new(rows: usize, cols: usize) -> Self {
12 Self {
13 data: vec![0.0; rows * cols],
14 rows,
15 cols,
16 }
17 }
18
19 pub fn from_slice(data: &[f64], rows: usize, cols: usize) -> Self {
20 assert_eq!(data.len(), rows * cols, "Data length must match matrix dimensions");
21 Self {
22 data: data.to_vec(),
23 rows,
24 cols,
25 }
26 }
27
28 pub fn identity(size: usize) -> Self {
29 let mut matrix = Self::new(size, size);
30 for i in 0..size {
31 matrix.data[i * size + i] = 1.0;
32 }
33 matrix
34 }
35
36 pub fn random(rows: usize, cols: usize) -> Self {
37 let mut matrix = Self::new(rows, cols);
38 for i in 0..matrix.data.len() {
39 #[cfg(feature = "wasm")]
40 {
41 matrix.data[i] = fastrand::f64();
42 }
43 #[cfg(not(feature = "wasm"))]
44 {
45 matrix.data[i] = rand::random::<f64>();
46 }
47 }
48 matrix
49 }
50
51 pub fn rows(&self) -> usize {
52 self.rows
53 }
54
55 pub fn cols(&self) -> usize {
56 self.cols
57 }
58
59 pub fn data(&self) -> &[f64] {
60 &self.data
61 }
62
63 pub fn data_mut(&mut self) -> &mut [f64] {
64 &mut self.data
65 }
66
67 pub fn get(&self, row: usize, col: usize) -> f64 {
68 assert!(row < self.rows && col < self.cols, "Index out of bounds");
69 self.data[row * self.cols + col]
70 }
71
72 pub fn set(&mut self, row: usize, col: usize, value: f64) {
73 assert!(row < self.rows && col < self.cols, "Index out of bounds");
74 self.data[row * self.cols + col] = value;
75 }
76
77 pub fn multiply(&self, other: &Matrix) -> Result<Matrix, String> {
78 if self.cols != other.rows {
79 return Err("Matrix dimensions incompatible for multiplication".to_string());
80 }
81
82 let mut result = Matrix::new(self.rows, other.cols);
83
84 for i in 0..self.rows {
85 for j in 0..other.cols {
86 let mut sum = 0.0;
87 for k in 0..self.cols {
88 sum += self.get(i, k) * other.get(k, j);
89 }
90 result.set(i, j, sum);
91 }
92 }
93
94 Ok(result)
95 }
96
97 pub fn transpose(&self) -> Matrix {
98 let mut result = Matrix::new(self.cols, self.rows);
99 for i in 0..self.rows {
100 for j in 0..self.cols {
101 result.set(j, i, self.get(i, j));
102 }
103 }
104 result
105 }
106
107 pub fn is_symmetric(&self) -> bool {
108 if self.rows != self.cols {
109 return false;
110 }
111
112 for i in 0..self.rows {
113 for j in 0..self.cols {
114 if (self.get(i, j) - self.get(j, i)).abs() > 1e-10 {
115 return false;
116 }
117 }
118 }
119 true
120 }
121
122 pub fn is_positive_definite(&self) -> bool {
123 if !self.is_symmetric() {
124 return false;
125 }
126
127 if self.rows <= 3 {
130 return self.check_sylvester_criterion();
131 }
132
133 true
135 }
136
137 fn check_sylvester_criterion(&self) -> bool {
138 for k in 1..=self.rows {
139 let det = self.leading_principal_minor(k);
140 if det <= 0.0 {
141 return false;
142 }
143 }
144 true
145 }
146
147 fn leading_principal_minor(&self, k: usize) -> f64 {
148 if k == 1 {
149 return self.get(0, 0);
150 }
151 if k == 2 {
152 return self.get(0, 0) * self.get(1, 1) - self.get(0, 1) * self.get(1, 0);
153 }
154 if k == 3 {
155 let a = self.get(0, 0);
156 let b = self.get(0, 1);
157 let c = self.get(0, 2);
158 let d = self.get(1, 0);
159 let e = self.get(1, 1);
160 let f = self.get(1, 2);
161 let g = self.get(2, 0);
162 let h = self.get(2, 1);
163 let i = self.get(2, 2);
164
165 return a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g);
166 }
167
168 1.0
170 }
171}
172
173impl fmt::Display for Matrix {
174 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
175 for i in 0..self.rows {
176 write!(f, "[")?;
177 for j in 0..self.cols {
178 if j > 0 {
179 write!(f, ", ")?;
180 }
181 write!(f, "{:8.4}", self.get(i, j))?;
182 }
183 writeln!(f, "]")?;
184 }
185 Ok(())
186 }
187}
188
/// A dense vector of `f64` values.
///
/// Equality (`PartialEq`) compares lengths and components with `f64` `==`, so a vector
/// containing NaN never equals itself.
#[derive(Debug, Clone, PartialEq)]
pub struct Vector {
    /// The components, in index order.
    data: Vec<f64>,
}
193
194impl Vector {
195 pub fn new(size: usize) -> Self {
196 Self {
197 data: vec![0.0; size],
198 }
199 }
200
201 pub fn from_slice(data: &[f64]) -> Self {
202 Self {
203 data: data.to_vec(),
204 }
205 }
206
207 pub fn zeros(size: usize) -> Self {
208 Self::new(size)
209 }
210
211 pub fn ones(size: usize) -> Self {
212 Self {
213 data: vec![1.0; size],
214 }
215 }
216
217 pub fn random(size: usize) -> Self {
218 let mut vector = Self::new(size);
219 for i in 0..size {
220 #[cfg(feature = "wasm")]
221 {
222 vector.data[i] = fastrand::f64();
223 }
224 #[cfg(not(feature = "wasm"))]
225 {
226 vector.data[i] = rand::random::<f64>();
227 }
228 }
229 vector
230 }
231
232 pub fn len(&self) -> usize {
233 self.data.len()
234 }
235
236 pub fn is_empty(&self) -> bool {
237 self.data.is_empty()
238 }
239
240 pub fn data(&self) -> &[f64] {
241 &self.data
242 }
243
244 pub fn data_mut(&mut self) -> &mut [f64] {
245 &mut self.data
246 }
247
248 pub fn get(&self, index: usize) -> f64 {
249 self.data[index]
250 }
251
252 pub fn set(&mut self, index: usize, value: f64) {
253 self.data[index] = value;
254 }
255
256 pub fn dot(&self, other: &Vector) -> f64 {
257 assert_eq!(self.len(), other.len(), "Vector lengths must match for dot product");
258
259 self.data.iter()
260 .zip(other.data.iter())
261 .map(|(a, b)| a * b)
262 .sum()
263 }
264
265 pub fn norm(&self) -> f64 {
266 self.dot(self).sqrt()
267 }
268
269 pub fn normalize(&mut self) {
270 let norm = self.norm();
271 if norm > 0.0 {
272 for x in &mut self.data {
273 *x /= norm;
274 }
275 }
276 }
277
278 pub fn add(&self, other: &Vector) -> Vector {
279 assert_eq!(self.len(), other.len(), "Vector lengths must match for addition");
280
281 let mut result = Vector::new(self.len());
282 for i in 0..self.len() {
283 result.data[i] = self.data[i] + other.data[i];
284 }
285 result
286 }
287
288 pub fn subtract(&self, other: &Vector) -> Vector {
289 assert_eq!(self.len(), other.len(), "Vector lengths must match for subtraction");
290
291 let mut result = Vector::new(self.len());
292 for i in 0..self.len() {
293 result.data[i] = self.data[i] - other.data[i];
294 }
295 result
296 }
297
298 pub fn scale(&self, scalar: f64) -> Vector {
299 let mut result = Vector::new(self.len());
300 for i in 0..self.len() {
301 result.data[i] = self.data[i] * scalar;
302 }
303 result
304 }
305
306 pub fn axpy(&mut self, alpha: f64, x: &Vector) {
307 assert_eq!(self.len(), x.len(), "Vector lengths must match for axpy");
308
309 for i in 0..self.len() {
310 self.data[i] += alpha * x.data[i];
311 }
312 }
313}
314
315impl fmt::Display for Vector {
316 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
317 write!(f, "[")?;
318 for (i, &value) in self.data.iter().enumerate() {
319 if i > 0 {
320 write!(f, ", ")?;
321 }
322 write!(f, "{:8.4}", value)?;
323 }
324 write!(f, "]")
325 }
326}
327
328impl Matrix {
330 pub fn multiply_vector(&self, vector: &Vector) -> Result<Vector, String> {
331 if self.cols != vector.len() {
332 return Err("Matrix columns must match vector length".to_string());
333 }
334
335 let mut result = Vector::new(self.rows);
336 for i in 0..self.rows {
337 let mut sum = 0.0;
338 for j in 0..self.cols {
339 sum += self.get(i, j) * vector.get(j);
340 }
341 result.set(i, sum);
342 }
343
344 Ok(result)
345 }
346}
347
#[cfg(test)]
mod tests {
    // Unit tests for the public Matrix / Vector API defined in this file.
    use super::*;

    #[test]
    fn test_matrix_creation() {
        let matrix = Matrix::new(3, 3);
        assert_eq!(matrix.rows(), 3);
        assert_eq!(matrix.cols(), 3);
        assert_eq!(matrix.data().len(), 9);
    }

    #[test]
    fn test_matrix_identity() {
        let identity = Matrix::identity(3);
        assert_eq!(identity.get(0, 0), 1.0);
        assert_eq!(identity.get(1, 1), 1.0);
        assert_eq!(identity.get(2, 2), 1.0);
        assert_eq!(identity.get(0, 1), 0.0);
    }

    #[test]
    fn test_vector_operations() {
        let v1 = Vector::from_slice(&[1.0, 2.0, 3.0]);
        let v2 = Vector::from_slice(&[4.0, 5.0, 6.0]);

        let dot_product = v1.dot(&v2);
        assert_eq!(dot_product, 32.0);

        let sum = v1.add(&v2);
        assert_eq!(sum.data(), &[5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_matrix_vector_multiply() {
        let matrix = Matrix::from_slice(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let vector = Vector::from_slice(&[1.0, 2.0]);

        let result = matrix.multiply_vector(&vector).unwrap();
        assert_eq!(result.data(), &[5.0, 11.0]);
    }

    #[test]
    fn test_matrix_multiply() {
        let a = Matrix::from_slice(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let b = Matrix::from_slice(&[5.0, 6.0, 7.0, 8.0], 2, 2);

        let product = a.multiply(&b).unwrap();
        assert_eq!(product.data(), &[19.0, 22.0, 43.0, 50.0]);

        // 2x2 times 3x2: inner dimensions disagree.
        assert!(a.multiply(&Matrix::new(3, 2)).is_err());
    }

    #[test]
    fn test_matrix_transpose() {
        let matrix = Matrix::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let transposed = matrix.transpose();

        assert_eq!((transposed.rows(), transposed.cols()), (3, 2));
        assert_eq!(transposed.data(), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_small_positive_definite() {
        let pd = Matrix::from_slice(&[2.0, 1.0, 1.0, 2.0], 2, 2);
        assert!(pd.is_positive_definite());

        let indefinite = Matrix::from_slice(&[1.0, 2.0, 2.0, 1.0], 2, 2);
        assert!(!indefinite.is_positive_definite());

        let asymmetric = Matrix::from_slice(&[1.0, 2.0, 0.0, 1.0], 2, 2);
        assert!(!asymmetric.is_positive_definite());
    }

    #[test]
    fn test_vector_normalize() {
        let mut v = Vector::from_slice(&[3.0, 4.0]);
        assert_eq!(v.norm(), 5.0);
        v.normalize();
        assert_eq!(v.data(), &[0.6, 0.8]);

        let mut zero = Vector::zeros(2);
        zero.normalize();
        assert_eq!(zero.data(), &[0.0, 0.0]);
    }

    #[test]
    fn test_display() {
        assert_eq!(
            format!("{}", Matrix::identity(2)),
            "[  1.0000,   0.0000]\n[  0.0000,   1.0000]\n"
        );
        assert_eq!(
            format!("{}", Vector::from_slice(&[1.0, -2.5])),
            "[  1.0000,  -2.5000]"
        );
    }
}