ruvector_math/tensor_networks/cp_decomposition.rs
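
//! CANDECOMP/PARAFAC (CP) decomposition of dense tensors, fitted with
//! alternating least squares (ALS). A rank-R CP model approximates a
//! tensor by a weighted sum of R outer products of per-mode factor
//! columns.
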
use super::DenseTensor;

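/// Configuration for the CP-ALS solver.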
#[derive(Debug, Clone)]
pub struct CPConfig {
    /// Target CP rank (number of rank-one components).
    pub rank: usize,
    /// Maximum number of ALS sweeps over the modes.
    pub max_iters: usize,
    /// Early-stopping threshold on the largest factor change per sweep.
    pub tolerance: f64,
}

impl Default for CPConfig {
    fn default() -> Self {
        Self {
            rank: 10,
            max_iters: 100,
            tolerance: 1e-8,
        }
    }
}

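/// A fitted rank-R CP model: the tensor is approximated by a weighted sum
/// of R outer products of the per-mode factor columns.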
#[derive(Debug, Clone)]
pub struct CPDecomposition {
    /// Per-component scaling weights.
    pub weights: Vec<f64>,
    /// One factor matrix per mode, stored row-major as `shape[k] x rank`.
    pub factors: Vec<Vec<f64>>,
    /// Shape of the decomposed tensor.
    pub shape: Vec<usize>,
    /// Number of rank-one components.
    pub rank: usize,
}

impl CPDecomposition {
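    /// Fits a rank-`config.rank` CP model with alternating least squares
    /// (ALS): each sweep re-solves one factor matrix at a time while
    /// holding the others fixed.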
    pub fn als(tensor: &DenseTensor, config: &CPConfig) -> Self {
        let d = tensor.order();
        let r = config.rank;

        // Deterministic pseudo-random initialization derived from the
        // entry index and the mode index.
        let mut factors: Vec<Vec<f64>> = tensor.shape.iter()
            .enumerate()
            .map(|(k, &n_k)| {
                (0..n_k * r)
                    .map(|i| ((i * 2654435769 + k * 1103515245) as f64 / 4294967296.0) * 2.0 - 1.0)
                    .collect()
            })
            .collect();

        // Unit-normalize every column so that component scale is carried
        // by `weights` rather than by any single factor.
        let mut weights = vec![1.0; r];
        for (k, factor) in factors.iter_mut().enumerate() {
            normalize_columns(factor, tensor.shape[k], r);
        }

        for _ in 0..config.max_iters {
            let mut max_change = 0.0f64;
            for k in 0..d {
                let previous = factors[k].clone();
                update_factor_als(tensor, &mut factors, k, r);
                // The freshly solved factor absorbs the component scales;
                // move them into `weights` and renormalize its columns.
                weights = normalize_columns(&mut factors[k], tensor.shape[k], r);
                for (new, old) in factors[k].iter().zip(previous.iter()) {
                    max_change = max_change.max((new - old).abs());
                }
            }
            // Stop early once no factor entry moves more than the tolerance.
            if max_change < config.tolerance {
                break;
            }
        }

        Self {
            weights,
            factors,
            shape: tensor.shape.clone(),
            rank: r,
        }
    }

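    /// Reconstructs the full dense tensor from the weights and factors.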
    pub fn to_dense(&self) -> DenseTensor {
        let total_size: usize = self.shape.iter().product();
        let mut data = vec![0.0; total_size];
        let d = self.shape.len();

        // Walk all entries in row-major order, keeping the multi-index in
        // sync with the flat index.
        let mut indices = vec![0usize; d];
        for flat_idx in 0..total_size {
            let mut val = 0.0;

            // Each entry is a weighted sum over the rank-one components.
            for col in 0..self.rank {
                let mut prod = self.weights[col];
                for (k, &idx) in indices.iter().enumerate() {
                    prod *= self.factors[k][idx * self.rank + col];
                }
                val += prod;
            }

            data[flat_idx] = val;

            // Advance the multi-index; the last mode varies fastest.
            for k in (0..d).rev() {
                indices[k] += 1;
                if indices[k] < self.shape[k] {
                    break;
                }
                indices[k] = 0;
            }
        }

        DenseTensor::new(data, self.shape.clone())
    }

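    /// Evaluates a single entry of the model at `indices` without
    /// materializing the dense tensor.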
    pub fn eval(&self, indices: &[usize]) -> f64 {
        let mut val = 0.0;

        for col in 0..self.rank {
            let mut prod = self.weights[col];
            for (k, &idx) in indices.iter().enumerate() {
                prod *= self.factors[k][idx * self.rank + col];
            }
            val += prod;
        }

        val
    }

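    /// Total number of stored scalars: the weights plus every factor entry.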
    pub fn storage(&self) -> usize {
        self.weights.len() + self.factors.iter().map(|f| f.len()).sum::<usize>()
    }

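    /// Ratio of dense-tensor size to CP storage; larger means more compression.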
    pub fn compression_ratio(&self) -> f64 {
        let original: usize = self.shape.iter().product();
        let storage = self.storage();
        if storage == 0 {
            return f64::INFINITY;
        }
        original as f64 / storage as f64
    }

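    /// Relative reconstruction error in the Frobenius norm,
    /// `||T - T_hat||_F / ||T||_F`.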
    pub fn relative_error(&self, tensor: &DenseTensor) -> f64 {
        let reconstructed = self.to_dense();

        let mut error_sq = 0.0;
        let mut tensor_sq = 0.0;

        for (a, b) in tensor.data.iter().zip(reconstructed.data.iter()) {
            error_sq += (a - b).powi(2);
            tensor_sq += a.powi(2);
        }

        (error_sq / tensor_sq.max(1e-15)).sqrt()
    }
}

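/// Normalizes each column of a row-major `rows x cols` matrix to unit
/// Euclidean norm and returns the original column norms.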
fn normalize_columns(factor: &mut [f64], rows: usize, cols: usize) -> Vec<f64> {
    let mut norms = vec![0.0; cols];
    for c in 0..cols {
        let mut norm = 0.0;
        for r in 0..rows {
            norm += factor[r * cols + c].powi(2);
        }
        norm = norm.sqrt();
        norms[c] = norm;

        // Leave (near-)zero columns untouched to avoid dividing by zero.
        if norm > 1e-15 {
            for r in 0..rows {
                factor[r * cols + c] /= norm;
            }
        }
    }
    norms
}

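/// ALS update for mode `k`: the new factor solves the normal equations
/// `A_k V = MTTKRP(k)`, where `V` is the Hadamard product of the Gram
/// matrices of all other factors.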
fn update_factor_als(tensor: &DenseTensor, factors: &mut [Vec<f64>], k: usize, rank: usize) {
    let d = tensor.order();
    let n_k = tensor.shape[k];

    // V accumulates the element-wise product of the Gram matrices
    // A_m^T A_m over every mode m != k.
    let mut v = vec![1.0; rank * rank];
    for m in 0..d {
        if m == k {
            continue;
        }

        let n_m = tensor.shape[m];
        let factor_m = &factors[m];

        // Gram matrix of factor m.
        let mut gram = vec![0.0; rank * rank];
        for i in 0..rank {
            for j in 0..rank {
                for row in 0..n_m {
                    gram[i * rank + j] += factor_m[row * rank + i] * factor_m[row * rank + j];
                }
            }
        }

        // Hadamard product into V.
        for i in 0..rank * rank {
            v[i] *= gram[i];
        }
    }

    let mttkrp = compute_mttkrp(tensor, factors, k, rank);

    // V is symmetric, so the least-squares solution is MTTKRP * pinv(V).
    let v_inv = pseudo_inverse_symmetric(&v, rank);

    let mut new_factor = vec![0.0; n_k * rank];
    for row in 0..n_k {
        for col in 0..rank {
            for c in 0..rank {
                new_factor[row * rank + col] += mttkrp[row * rank + c] * v_inv[c * rank + col];
            }
        }
    }

    factors[k] = new_factor;
}

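/// Matricized-tensor times Khatri-Rao product (MTTKRP) for mode `k`,
/// computed by brute force over every tensor entry.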
fn compute_mttkrp(tensor: &DenseTensor, factors: &[Vec<f64>], k: usize, rank: usize) -> Vec<f64> {
    let d = tensor.order();
    let n_k = tensor.shape[k];
    let mut result = vec![0.0; n_k * rank];

    // Iterate over all entries, tracking the multi-index alongside the
    // flat index (row-major, last mode fastest).
    let total_size: usize = tensor.shape.iter().product();
    let mut indices = vec![0usize; d];

    for flat_idx in 0..total_size {
        let val = tensor.data[flat_idx];
        let i_k = indices[k];

        for col in 0..rank {
            let mut prod = val;
            for (m, &idx) in indices.iter().enumerate() {
                if m != k {
                    prod *= factors[m][idx * rank + col];
                }
            }
            result[i_k * rank + col] += prod;
        }

        // Advance the multi-index.
        for m in (0..d).rev() {
            indices[m] += 1;
            if indices[m] < tensor.shape[m] {
                break;
            }
            indices[m] = 0;
        }
    }

    result
}

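/// Ridge-regularized inverse of a symmetric `n x n` matrix computed with
/// Gauss-Jordan elimination and partial pivoting.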
fn pseudo_inverse_symmetric(a: &[f64], n: usize) -> Vec<f64> {
    // A small ridge term keeps the system invertible even when the
    // Hadamard-Gram product is rank-deficient.
    let eps = 1e-10;

    let mut a_reg = a.to_vec();
    for i in 0..n {
        a_reg[i * n + i] += eps;
    }

    // Build the augmented matrix [A_reg | I].
    let mut augmented = vec![0.0; n * 2 * n];
    for i in 0..n {
        for j in 0..n {
            augmented[i * 2 * n + j] = a_reg[i * n + j];
        }
        augmented[i * 2 * n + n + i] = 1.0;
    }

    for col in 0..n {
        // Partial pivoting: pick the row with the largest entry in this column.
        let mut max_row = col;
        for row in col + 1..n {
            if augmented[row * 2 * n + col].abs() > augmented[max_row * 2 * n + col].abs() {
                max_row = row;
            }
        }

        for j in 0..2 * n {
            augmented.swap(col * 2 * n + j, max_row * 2 * n + j);
        }

        let pivot = augmented[col * 2 * n + col];
        if pivot.abs() < 1e-15 {
            continue;
        }

        // Scale the pivot row so the pivot becomes 1.
        for j in 0..2 * n {
            augmented[col * 2 * n + j] /= pivot;
        }

        // Eliminate this column from every other row.
        for row in 0..n {
            if row == col {
                continue;
            }
            let factor = augmented[row * 2 * n + col];
            for j in 0..2 * n {
                augmented[row * 2 * n + j] -= factor * augmented[col * 2 * n + j];
            }
        }
    }

    // The right half of the augmented matrix now holds the inverse.
    let mut inv = vec![0.0; n * n];
    for i in 0..n {
        for j in 0..n {
            inv[i * n + j] = augmented[i * 2 * n + n + j];
        }
    }

    inv
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cp_als() {
        let tensor = DenseTensor::random(vec![4, 5, 3], 42);

        let config = CPConfig {
            rank: 5,
            max_iters: 50,
            ..Default::default()
        };

        let cp = CPDecomposition::als(&tensor, &config);

        assert_eq!(cp.rank, 5);
        assert_eq!(cp.weights.len(), 5);

        let error = cp.relative_error(&tensor);
        assert!(error.is_finite(), "Error should be finite: {}", error);
    }

    #[test]
    fn test_cp_eval() {
        let tensor = DenseTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);

        let config = CPConfig {
            rank: 2,
            max_iters: 50,
            ..Default::default()
        };

        let cp = CPDecomposition::als(&tensor, &config);

        // A 2x3 target has matrix rank at most 2, so the rank-2 model
        // should reconstruct it almost exactly.
        let reconstructed = cp.to_dense();
        for (a, b) in tensor.data.iter().zip(reconstructed.data.iter()) {
            assert!((a - b).abs() < 1e-3, "reconstruction mismatch: {} vs {}", a, b);
        }
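
        // `eval` should agree entry-wise with the row-major dense
        // reconstruction.
        for i in 0..2 {
            for j in 0..3 {
                let expected = reconstructed.data[i * 3 + j];
                assert!((cp.eval(&[i, j]) - expected).abs() < 1e-12);
            }
        }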
    }
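
    // Sanity-check sketch (an illustrative addition using the same
    // `DenseTensor::new` constructor as above): an exactly rank-1 tensor
    // should be recovered by a rank-1 model with near-zero relative error.
    #[test]
    fn test_cp_rank_one_recovery() {
        // Build T[i][j] = u[i] * v[j] as a flat row-major buffer.
        let u = [1.0, 2.0];
        let v = [3.0, -1.0, 0.5];
        let mut data = Vec::with_capacity(u.len() * v.len());
        for &ui in &u {
            for &vj in &v {
                data.push(ui * vj);
            }
        }
        let tensor = DenseTensor::new(data, vec![2, 3]);

        let config = CPConfig {
            rank: 1,
            max_iters: 50,
            ..Default::default()
        };
        let cp = CPDecomposition::als(&tensor, &config);

        assert!(cp.relative_error(&tensor) < 1e-6);
    }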
}