Skip to main content

yscv_tensor/
linalg.rs

1use super::error::TensorError;
2use super::tensor::Tensor;
3
4impl Tensor {
5    /// Sum of diagonal elements of a 2D square matrix.
6    pub fn trace(&self) -> Result<f32, TensorError> {
7        if self.rank() != 2 {
8            return Err(TensorError::UnsupportedOperation {
9                msg: format!("trace requires a 2D tensor, got rank {}", self.rank()),
10            });
11        }
12        let (rows, cols) = (self.shape()[0], self.shape()[1]);
13        if rows != cols {
14            return Err(TensorError::ShapeMismatch {
15                left: self.shape().to_vec(),
16                right: vec![rows, rows],
17            });
18        }
19        let data = self.data();
20        let mut sum = 0.0f32;
21        for i in 0..rows {
22            sum += data[i * cols + i];
23        }
24        Ok(sum)
25    }
26
27    /// Dot product of two 1D tensors.
28    pub fn dot(&self, rhs: &Self) -> Result<f32, TensorError> {
29        if self.rank() != 1 || rhs.rank() != 1 {
30            return Err(TensorError::UnsupportedOperation {
31                msg: format!(
32                    "dot requires two 1D tensors, got ranks {} and {}",
33                    self.rank(),
34                    rhs.rank()
35                ),
36            });
37        }
38        if self.shape() != rhs.shape() {
39            return Err(TensorError::ShapeMismatch {
40                left: self.shape().to_vec(),
41                right: rhs.shape().to_vec(),
42            });
43        }
44        let a = self.data();
45        let b = rhs.data();
46        let mut sum = 0.0f32;
47        for i in 0..a.len() {
48            sum += a[i] * b[i];
49        }
50        Ok(sum)
51    }
52
53    /// Cross product of two 3-element 1D tensors.
54    pub fn cross(&self, rhs: &Self) -> Result<Self, TensorError> {
55        if self.rank() != 1 || rhs.rank() != 1 {
56            return Err(TensorError::UnsupportedOperation {
57                msg: "cross requires two 1D tensors".into(),
58            });
59        }
60        if self.shape()[0] != 3 || rhs.shape()[0] != 3 {
61            return Err(TensorError::ShapeMismatch {
62                left: self.shape().to_vec(),
63                right: rhs.shape().to_vec(),
64            });
65        }
66        let a = self.data();
67        let b = rhs.data();
68        let result = vec![
69            a[1] * b[2] - a[2] * b[1],
70            a[2] * b[0] - a[0] * b[2],
71            a[0] * b[1] - a[1] * b[0],
72        ];
73        Tensor::from_vec(vec![3], result)
74    }
75
76    /// Lp norm of all elements. p=1 for L1, p=2 for L2.
77    pub fn norm(&self, p: f32) -> f32 {
78        let data = self.data();
79        if p == 1.0 {
80            data.iter().map(|x| x.abs()).sum()
81        } else if p == 2.0 {
82            data.iter().map(|x| x * x).sum::<f32>().sqrt()
83        } else {
84            data.iter()
85                .map(|x| x.abs().powf(p))
86                .sum::<f32>()
87                .powf(1.0 / p)
88        }
89    }
90
91    /// Determinant of a square matrix (LU-based).
92    pub fn det(&self) -> Result<f32, TensorError> {
93        if self.rank() != 2 {
94            return Err(TensorError::UnsupportedOperation {
95                msg: format!("det requires a 2D tensor, got rank {}", self.rank()),
96            });
97        }
98        let n = self.shape()[0];
99        if n != self.shape()[1] {
100            return Err(TensorError::ShapeMismatch {
101                left: self.shape().to_vec(),
102                right: vec![n, n],
103            });
104        }
105
106        // Copy data into a working matrix
107        let mut a: Vec<f32> = self.data().to_vec();
108        let mut sign = 1.0f32;
109
110        for col in 0..n {
111            // Partial pivoting: find max in column
112            let mut max_row = col;
113            let mut max_val = a[col * n + col].abs();
114            for row in (col + 1)..n {
115                let v = a[row * n + col].abs();
116                if v > max_val {
117                    max_val = v;
118                    max_row = row;
119                }
120            }
121            if max_val < 1e-12 {
122                return Ok(0.0);
123            }
124            if max_row != col {
125                // Swap rows
126                for j in 0..n {
127                    a.swap(col * n + j, max_row * n + j);
128                }
129                sign = -sign;
130            }
131            let pivot = a[col * n + col];
132            for row in (col + 1)..n {
133                let factor = a[row * n + col] / pivot;
134                for j in col..n {
135                    let val = a[col * n + j];
136                    a[row * n + j] -= factor * val;
137                }
138            }
139        }
140
141        let mut det = sign;
142        for i in 0..n {
143            det *= a[i * n + i];
144        }
145        Ok(det)
146    }
147
148    /// Inverse of a square matrix (Gauss-Jordan elimination).
149    pub fn inv(&self) -> Result<Self, TensorError> {
150        if self.rank() != 2 {
151            return Err(TensorError::UnsupportedOperation {
152                msg: format!("inv requires a 2D tensor, got rank {}", self.rank()),
153            });
154        }
155        let n = self.shape()[0];
156        if n != self.shape()[1] {
157            return Err(TensorError::ShapeMismatch {
158                left: self.shape().to_vec(),
159                right: vec![n, n],
160            });
161        }
162
163        // Augmented matrix [A | I], stored as n x 2n
164        let data = self.data();
165        let nn = 2 * n;
166        let mut aug = vec![0.0f32; n * nn];
167        for i in 0..n {
168            for j in 0..n {
169                aug[i * nn + j] = data[i * n + j];
170            }
171            aug[i * nn + n + i] = 1.0;
172        }
173
174        // Gauss-Jordan with partial pivoting
175        for col in 0..n {
176            // Find pivot
177            let mut max_row = col;
178            let mut max_val = aug[col * nn + col].abs();
179            for row in (col + 1)..n {
180                let v = aug[row * nn + col].abs();
181                if v > max_val {
182                    max_val = v;
183                    max_row = row;
184                }
185            }
186            if max_val < 1e-12 {
187                return Err(TensorError::UnsupportedOperation {
188                    msg: "matrix is singular".into(),
189                });
190            }
191            if max_row != col {
192                for j in 0..nn {
193                    aug.swap(col * nn + j, max_row * nn + j);
194                }
195            }
196
197            // Scale pivot row
198            let pivot = aug[col * nn + col];
199            for j in 0..nn {
200                aug[col * nn + j] /= pivot;
201            }
202
203            // Eliminate column in all other rows
204            for row in 0..n {
205                if row == col {
206                    continue;
207                }
208                let factor = aug[row * nn + col];
209                for j in 0..nn {
210                    let val = aug[col * nn + j];
211                    aug[row * nn + j] -= factor * val;
212                }
213            }
214        }
215
216        // Extract the right half
217        let mut result = vec![0.0f32; n * n];
218        for i in 0..n {
219            for j in 0..n {
220                result[i * n + j] = aug[i * nn + n + j];
221            }
222        }
223        Tensor::from_vec(vec![n, n], result)
224    }
225
226    /// Solve linear system Ax = b. self is A, rhs is b.
227    pub fn solve(&self, b: &Self) -> Result<Self, TensorError> {
228        if self.rank() != 2 {
229            return Err(TensorError::UnsupportedOperation {
230                msg: format!("solve requires A to be 2D, got rank {}", self.rank()),
231            });
232        }
233        let n = self.shape()[0];
234        if n != self.shape()[1] {
235            return Err(TensorError::ShapeMismatch {
236                left: self.shape().to_vec(),
237                right: vec![n, n],
238            });
239        }
240        if b.rank() != 1 || b.shape()[0] != n {
241            return Err(TensorError::ShapeMismatch {
242                left: self.shape().to_vec(),
243                right: b.shape().to_vec(),
244            });
245        }
246
247        // LU decomposition with partial pivoting
248        let data = self.data();
249        let mut a = data.to_vec();
250        let mut perm: Vec<usize> = (0..n).collect();
251
252        for col in 0..n {
253            // Partial pivoting
254            let mut max_row = col;
255            let mut max_val = a[col * n + col].abs();
256            for row in (col + 1)..n {
257                let v = a[row * n + col].abs();
258                if v > max_val {
259                    max_val = v;
260                    max_row = row;
261                }
262            }
263            if max_val < 1e-12 {
264                return Err(TensorError::UnsupportedOperation {
265                    msg: "matrix is singular".into(),
266                });
267            }
268            if max_row != col {
269                for j in 0..n {
270                    a.swap(col * n + j, max_row * n + j);
271                }
272                perm.swap(col, max_row);
273            }
274            let pivot = a[col * n + col];
275            for row in (col + 1)..n {
276                let factor = a[row * n + col] / pivot;
277                a[row * n + col] = factor; // store L factor
278                for j in (col + 1)..n {
279                    let val = a[col * n + j];
280                    a[row * n + j] -= factor * val;
281                }
282            }
283        }
284
285        // Apply permutation to b
286        let bd = b.data();
287        let mut pb = vec![0.0f32; n];
288        for i in 0..n {
289            pb[i] = bd[perm[i]];
290        }
291
292        // Forward substitution (Ly = Pb)
293        let mut y = pb;
294        for i in 1..n {
295            for j in 0..i {
296                let l_ij = a[i * n + j];
297                y[i] -= l_ij * y[j];
298            }
299        }
300
301        // Back substitution (Ux = y)
302        let mut x = y;
303        for i in (0..n).rev() {
304            for j in (i + 1)..n {
305                let u_ij = a[i * n + j];
306                x[i] -= u_ij * x[j];
307            }
308            x[i] /= a[i * n + i];
309        }
310
311        Tensor::from_vec(vec![n], x)
312    }
313
314    /// QR decomposition via Householder reflections. Returns (Q, R).
315    pub fn qr(&self) -> Result<(Self, Self), TensorError> {
316        if self.rank() != 2 {
317            return Err(TensorError::UnsupportedOperation {
318                msg: format!("qr requires a 2D tensor, got rank {}", self.rank()),
319            });
320        }
321        let m = self.shape()[0];
322        let n = self.shape()[1];
323        let data = self.data();
324
325        // Working copy of the matrix (will become R)
326        let mut r = data.to_vec();
327
328        // Q starts as identity
329        let mut q = vec![0.0f32; m * m];
330        for i in 0..m {
331            q[i * m + i] = 1.0;
332        }
333
334        let k = m.min(n);
335        for j in 0..k {
336            // Extract column j from row j..m
337            let mut col = vec![0.0f32; m - j];
338            for i in j..m {
339                col[i - j] = r[i * n + j];
340            }
341
342            // Compute the Householder vector
343            let norm_col: f32 = col.iter().map(|x| x * x).sum::<f32>().sqrt();
344            if norm_col < 1e-12 {
345                continue;
346            }
347            let sign = if col[0] >= 0.0 { 1.0 } else { -1.0 };
348            col[0] += sign * norm_col;
349
350            let norm_v: f32 = col.iter().map(|x| x * x).sum::<f32>();
351            if norm_v < 1e-24 {
352                continue;
353            }
354
355            // Apply Householder reflection to R: R = R - 2 * v * (v^T * R) / (v^T * v)
356            // Only rows j..m, cols j..n
357            for jj in j..n {
358                let mut dot = 0.0f32;
359                for i in j..m {
360                    dot += col[i - j] * r[i * n + jj];
361                }
362                let factor = 2.0 * dot / norm_v;
363                for i in j..m {
364                    r[i * n + jj] -= factor * col[i - j];
365                }
366            }
367
368            // Apply Householder reflection to Q: Q = Q - 2 * Q * v * v^T / (v^T * v)
369            // Q is m x m, we update all rows of Q, cols j..m
370            for i in 0..m {
371                let mut dot = 0.0f32;
372                for jj in j..m {
373                    dot += q[i * m + jj] * col[jj - j];
374                }
375                let factor = 2.0 * dot / norm_v;
376                for jj in j..m {
377                    q[i * m + jj] -= factor * col[jj - j];
378                }
379            }
380        }
381
382        let q_tensor = Tensor::from_vec(vec![m, m], q)?;
383        let r_tensor = Tensor::from_vec(vec![m, n], r)?;
384        Ok((q_tensor, r_tensor))
385    }
386
387    /// Cholesky decomposition of a symmetric positive-definite matrix. Returns lower triangular L.
388    pub fn cholesky(&self) -> Result<Self, TensorError> {
389        if self.rank() != 2 {
390            return Err(TensorError::UnsupportedOperation {
391                msg: format!("cholesky requires a 2D tensor, got rank {}", self.rank()),
392            });
393        }
394        let n = self.shape()[0];
395        if n != self.shape()[1] {
396            return Err(TensorError::ShapeMismatch {
397                left: self.shape().to_vec(),
398                right: vec![n, n],
399            });
400        }
401        let data = self.data();
402        let mut l = vec![0.0f32; n * n];
403
404        for i in 0..n {
405            for j in 0..=i {
406                let mut sum = 0.0f32;
407                for k in 0..j {
408                    sum += l[i * n + k] * l[j * n + k];
409                }
410                if i == j {
411                    let diag = data[i * n + i] - sum;
412                    if diag <= 0.0 {
413                        return Err(TensorError::UnsupportedOperation {
414                            msg: "matrix is not positive definite".into(),
415                        });
416                    }
417                    l[i * n + j] = diag.sqrt();
418                } else {
419                    l[i * n + j] = (data[i * n + j] - sum) / l[j * n + j];
420                }
421            }
422        }
423
424        Tensor::from_vec(vec![n, n], l)
425    }
426}