// ruvector_math/tensor_networks/cp_decomposition.rs

use super::DenseTensor;
/// Configuration for a CP (CANDECOMP/PARAFAC) decomposition computed by
/// alternating least squares (`CPDecomposition::als`).
#[derive(Debug, Clone)]
pub struct CPConfig {
    /// Number of rank-1 components in the decomposition.
    pub rank: usize,
    /// Maximum number of ALS sweeps over all tensor modes.
    pub max_iters: usize,
    /// Convergence threshold for the ALS iteration.
    pub tolerance: f64,
}
20
21impl Default for CPConfig {
22 fn default() -> Self {
23 Self {
24 rank: 10,
25 max_iters: 100,
26 tolerance: 1e-8,
27 }
28 }
29}
30
/// A CP (CANDECOMP/PARAFAC) decomposition: the tensor approximated as a sum
/// of `rank` weighted rank-1 outer products.
#[derive(Debug, Clone)]
pub struct CPDecomposition {
    /// Per-component scaling weights (length `rank`).
    pub weights: Vec<f64>,
    /// One factor matrix per tensor mode; each is stored row-major with
    /// dimensions `shape[k] x rank` (entry `(i, col)` at `i * rank + col`).
    pub factors: Vec<Vec<f64>>,
    /// Shape of the original tensor.
    pub shape: Vec<usize>,
    /// Number of rank-1 components.
    pub rank: usize,
}
43
44impl CPDecomposition {
45 pub fn als(tensor: &DenseTensor, config: &CPConfig) -> Self {
47 let d = tensor.order();
48 let r = config.rank;
49
50 let mut factors: Vec<Vec<f64>> = tensor
52 .shape
53 .iter()
54 .enumerate()
55 .map(|(k, &n_k)| {
56 (0..n_k * r)
57 .map(|i| {
58 let x =
59 ((i * 2654435769 + k * 1103515245) as f64 / 4294967296.0) * 2.0 - 1.0;
60 x
61 })
62 .collect()
63 })
64 .collect();
65
66 let mut weights = vec![1.0; r];
68 for (k, factor) in factors.iter_mut().enumerate() {
69 normalize_columns(factor, tensor.shape[k], r);
70 }
71
72 for _ in 0..config.max_iters {
74 for k in 0..d {
75 update_factor_als(tensor, &mut factors, k, r);
77 normalize_columns(&mut factors[k], tensor.shape[k], r);
78 }
79 }
80
81 for col in 0..r {
83 let mut norm = 0.0;
84 for row in 0..tensor.shape[0] {
85 norm += factors[0][row * r + col].powi(2);
86 }
87 weights[col] = norm.sqrt();
88
89 if weights[col] > 1e-15 {
90 for row in 0..tensor.shape[0] {
91 factors[0][row * r + col] /= weights[col];
92 }
93 }
94 }
95
96 Self {
97 weights,
98 factors,
99 shape: tensor.shape.clone(),
100 rank: r,
101 }
102 }
103
104 pub fn to_dense(&self) -> DenseTensor {
106 let total_size: usize = self.shape.iter().product();
107 let mut data = vec![0.0; total_size];
108 let d = self.shape.len();
109
110 let mut indices = vec![0usize; d];
112 for flat_idx in 0..total_size {
113 let mut val = 0.0;
114
115 for col in 0..self.rank {
117 let mut prod = self.weights[col];
118 for (k, &idx) in indices.iter().enumerate() {
119 prod *= self.factors[k][idx * self.rank + col];
120 }
121 val += prod;
122 }
123
124 data[flat_idx] = val;
125
126 for k in (0..d).rev() {
128 indices[k] += 1;
129 if indices[k] < self.shape[k] {
130 break;
131 }
132 indices[k] = 0;
133 }
134 }
135
136 DenseTensor::new(data, self.shape.clone())
137 }
138
139 pub fn eval(&self, indices: &[usize]) -> f64 {
141 let mut val = 0.0;
142
143 for col in 0..self.rank {
144 let mut prod = self.weights[col];
145 for (k, &idx) in indices.iter().enumerate() {
146 prod *= self.factors[k][idx * self.rank + col];
147 }
148 val += prod;
149 }
150
151 val
152 }
153
154 pub fn storage(&self) -> usize {
156 self.weights.len() + self.factors.iter().map(|f| f.len()).sum::<usize>()
157 }
158
159 pub fn compression_ratio(&self) -> f64 {
161 let original: usize = self.shape.iter().product();
162 let storage = self.storage();
163 if storage == 0 {
164 return f64::INFINITY;
165 }
166 original as f64 / storage as f64
167 }
168
169 pub fn relative_error(&self, tensor: &DenseTensor) -> f64 {
171 let reconstructed = self.to_dense();
172
173 let mut error_sq = 0.0;
174 let mut tensor_sq = 0.0;
175
176 for (a, b) in tensor.data.iter().zip(reconstructed.data.iter()) {
177 error_sq += (a - b).powi(2);
178 tensor_sq += a.powi(2);
179 }
180
181 (error_sq / tensor_sq.max(1e-15)).sqrt()
182 }
183}
184
/// Scales each column of the row-major `rows x cols` matrix `factor` to unit
/// Euclidean norm. Columns with (near-)zero norm are left untouched to avoid
/// division by zero.
fn normalize_columns(factor: &mut [f64], rows: usize, cols: usize) {
    for col in 0..cols {
        let norm = (0..rows)
            .map(|row| {
                let v = factor[row * cols + col];
                v * v
            })
            .sum::<f64>()
            .sqrt();

        if norm > 1e-15 {
            (0..rows).for_each(|row| factor[row * cols + col] /= norm);
        }
    }
}
201
202fn update_factor_als(tensor: &DenseTensor, factors: &mut [Vec<f64>], k: usize, rank: usize) {
204 let d = tensor.order();
205 let n_k = tensor.shape[k];
206
207 let mut v = vec![1.0; rank * rank];
212 for m in 0..d {
213 if m == k {
214 continue;
215 }
216
217 let n_m = tensor.shape[m];
218 let factor_m = &factors[m];
219
220 let mut gram = vec![0.0; rank * rank];
222 for i in 0..rank {
223 for j in 0..rank {
224 for row in 0..n_m {
225 gram[i * rank + j] += factor_m[row * rank + i] * factor_m[row * rank + j];
226 }
227 }
228 }
229
230 for i in 0..rank * rank {
232 v[i] *= gram[i];
233 }
234 }
235
236 let mttkrp = compute_mttkrp(tensor, factors, k, rank);
238
239 let v_inv = pseudo_inverse_symmetric(&v, rank);
242
243 let mut new_factor = vec![0.0; n_k * rank];
244 for row in 0..n_k {
245 for col in 0..rank {
246 for c in 0..rank {
247 new_factor[row * rank + col] += mttkrp[row * rank + c] * v_inv[c * rank + col];
248 }
249 }
250 }
251
252 factors[k] = new_factor;
253}
254
/// Matricized-tensor-times-Khatri-Rao-product (MTTKRP) for mode `k`.
///
/// Returns a row-major `shape[k] x rank` matrix `M` with
/// `M[i_k, col] = Σ T[idx] * Π_{m != k} factors[m][idx_m, col]`, where the sum
/// runs over all tensor entries whose k-th index equals `i_k`. Each tensor
/// entry is visited exactly once, walking the multi-index in row-major
/// (odometer) order so `flat_idx` matches `tensor.data`'s layout.
fn compute_mttkrp(tensor: &DenseTensor, factors: &[Vec<f64>], k: usize, rank: usize) -> Vec<f64> {
    let d = tensor.order();
    let n_k = tensor.shape[k];
    let mut result = vec![0.0; n_k * rank];

    let total_size: usize = tensor.shape.iter().product();
    let mut indices = vec![0usize; d];

    for flat_idx in 0..total_size {
        let val = tensor.data[flat_idx];
        let i_k = indices[k];

        // Accumulate this entry's contribution to every component column.
        for col in 0..rank {
            let mut prod = val;
            for (m, &idx) in indices.iter().enumerate() {
                if m != k {
                    prod *= factors[m][idx * rank + col];
                }
            }
            result[i_k * rank + col] += prod;
        }

        // Advance the multi-index odometer (last mode varies fastest).
        for m in (0..d).rev() {
            indices[m] += 1;
            if indices[m] < tensor.shape[m] {
                break;
            }
            indices[m] = 0;
        }
    }

    result
}
291
/// Inverts the symmetric `n x n` row-major matrix `a` via Gauss–Jordan
/// elimination with partial pivoting, after adding a small Tikhonov term
/// `EPS * I` so near-singular Gram matrices remain invertible (a cheap
/// regularized stand-in for a true pseudo-inverse).
fn pseudo_inverse_symmetric(a: &[f64], n: usize) -> Vec<f64> {
    const EPS: f64 = 1e-10;
    let width = 2 * n;

    // Build the augmented system [a + EPS*I | I].
    let mut aug = vec![0.0; n * width];
    for r in 0..n {
        for c in 0..n {
            aug[r * width + c] = a[r * n + c];
        }
        aug[r * width + r] += EPS;
        aug[r * width + n + r] = 1.0;
    }

    for col in 0..n {
        // Partial pivoting: pick the first row (at or below `col`) with the
        // largest absolute entry in this column.
        let mut pivot_row = col;
        for r in col + 1..n {
            if aug[r * width + col].abs() > aug[pivot_row * width + col].abs() {
                pivot_row = r;
            }
        }
        for j in 0..width {
            aug.swap(col * width + j, pivot_row * width + j);
        }

        let pivot = aug[col * width + col];
        // Column is numerically zero even after regularization; skip it
        // (best-effort behavior rather than failing).
        if pivot.abs() < 1e-15 {
            continue;
        }

        // Scale the pivot row, then eliminate the column everywhere else.
        for j in 0..width {
            aug[col * width + j] /= pivot;
        }
        for r in 0..n {
            if r == col {
                continue;
            }
            let scale = aug[r * width + col];
            for j in 0..width {
                aug[r * width + j] -= scale * aug[col * width + j];
            }
        }
    }

    // The right half of the augmented matrix now holds the inverse.
    let mut inv = vec![0.0; n * n];
    for r in 0..n {
        for c in 0..n {
            inv[r * n + c] = aug[r * width + n + c];
        }
    }

    inv
}
358
#[cfg(test)]
mod tests {
    use super::*;

    /// ALS on a random 3-way tensor produces a structurally valid model.
    #[test]
    fn test_cp_als() {
        let tensor = DenseTensor::random(vec![4, 5, 3], 42);

        let config = CPConfig {
            rank: 5,
            max_iters: 50,
            ..Default::default()
        };

        let cp = CPDecomposition::als(&tensor, &config);

        assert_eq!(cp.rank, 5);
        assert_eq!(cp.weights.len(), 5);
        assert_eq!(cp.shape, vec![4, 5, 3]);
        // One factor matrix per mode, each shape[k] x rank.
        assert_eq!(cp.factors.len(), 3);
        for (k, factor) in cp.factors.iter().enumerate() {
            assert_eq!(factor.len(), tensor.shape[k] * 5);
        }

        let error = cp.relative_error(&tensor);
        assert!(error.is_finite(), "Error should be finite: {}", error);
    }

    /// `eval` and `to_dense` must agree entry-for-entry.
    #[test]
    fn test_cp_eval() {
        let tensor = DenseTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);

        let config = CPConfig {
            rank: 2,
            max_iters: 50,
            ..Default::default()
        };

        let cp = CPDecomposition::als(&tensor, &config);

        // BUG FIX: this loop previously had an empty body and asserted
        // nothing. Whatever ALS converged to, pointwise evaluation must match
        // the materialized dense tensor exactly (same arithmetic).
        let reconstructed = cp.to_dense();
        for i in 0..2 {
            for j in 0..3 {
                let direct = cp.eval(&[i, j]);
                let dense = reconstructed.data[i * 3 + j];
                assert!(
                    (direct - dense).abs() < 1e-9,
                    "eval/to_dense mismatch at ({}, {}): {} vs {}",
                    i,
                    j,
                    direct,
                    dense
                );
            }
        }
    }
}