ruvector_math/tensor_networks/tensor_train.rs

//! Tensor Train (TT) Decomposition
//!
//! The Tensor Train format represents a d-dimensional tensor as:
//!
//! A[i1, i2, ..., id] = G1[i1] × G2[i2] × ... × Gd[id]
//!
//! where Gk[ik] is the (r_{k-1} × r_k) matrix slice of the k-th TT-core
//! at mode index ik. The boundary ranks are r0 = rd = 1, so the product
//! collapses to a scalar.
//!
//! ## Complexity
//!
//! With n the largest mode size and r the largest TT-rank:
//!
//! - Storage: O(d * n * r²) instead of O(n^d)
//! - Dot product: O(d * n * r⁴) with the naive contraction used here
//!   (O(d * n * r³) with a two-step contraction order)
//! - Addition: O(d * n * r²), with ranks doubling
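//!
//! ## Example
//!
//! A minimal rank-1 usage sketch (the crate path is assumed from this
//! file's location under `ruvector_math/tensor_networks/`):
//!
//! ```ignore
//! use ruvector_math::tensor_networks::tensor_train::TensorTrain;
//!
//! // Outer product of [1, 2] and [3, 4] as a rank-1 tensor train
//! let tt = TensorTrain::from_vectors(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
//! assert!((tt.eval(&[1, 0]) - 6.0).abs() < 1e-12); // 2 * 3
//! assert_eq!(tt.max_rank(), 1);
//! ```
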
use super::DenseTensor;

/// Tensor Train configuration
#[derive(Debug, Clone)]
pub struct TensorTrainConfig {
    /// Maximum rank (0 = no limit)
    pub max_rank: usize,
    /// Truncation tolerance
    pub tolerance: f64,
}

impl Default for TensorTrainConfig {
    fn default() -> Self {
        Self {
            max_rank: 0,
            tolerance: 1e-12,
        }
    }
}
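
// For truncated decompositions, cap the rank explicitly, e.g.
// `TensorTrainConfig { max_rank: 8, tolerance: 1e-8 }` keeps every
// TT-rank at most 8 during `TensorTrain::from_dense`.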

/// A single TT-core: 3D tensor of shape (rank_left, mode_size, rank_right)
#[derive(Debug, Clone)]
pub struct TTCore {
    /// Core data in row-major order: [rank_left, mode_size, rank_right]
    pub data: Vec<f64>,
    /// Left rank
    pub rank_left: usize,
    /// Mode size
    pub mode_size: usize,
    /// Right rank
    pub rank_right: usize,
}

impl TTCore {
    /// Create new TT-core
    pub fn new(data: Vec<f64>, rank_left: usize, mode_size: usize, rank_right: usize) -> Self {
        assert_eq!(data.len(), rank_left * mode_size * rank_right);
        Self { data, rank_left, mode_size, rank_right }
    }

    /// Create zeros core
    pub fn zeros(rank_left: usize, mode_size: usize, rank_right: usize) -> Self {
        Self {
            data: vec![0.0; rank_left * mode_size * rank_right],
            rank_left,
            mode_size,
            rank_right,
        }
    }

    /// Get the (rank_left × rank_right) matrix slice for mode index i,
    /// in row-major order
    pub fn get_matrix(&self, i: usize) -> Vec<f64> {
        // Gather the i-th slice from the [rank_left, mode_size, rank_right] layout
        let mut result = vec![0.0; self.rank_left * self.rank_right];
        for rl in 0..self.rank_left {
            for rr in 0..self.rank_right {
                let idx = rl * self.mode_size * self.rank_right + i * self.rank_right + rr;
                result[rl * self.rank_right + rr] = self.data[idx];
            }
        }
        result
    }

    /// Set element at (rank_left, mode, rank_right) position
    pub fn set(&mut self, rl: usize, i: usize, rr: usize, value: f64) {
        let idx = rl * self.mode_size * self.rank_right + i * self.rank_right + rr;
        self.data[idx] = value;
    }

    /// Get element at (rank_left, mode, rank_right) position
    pub fn get(&self, rl: usize, i: usize, rr: usize) -> f64 {
        let idx = rl * self.mode_size * self.rank_right + i * self.rank_right + rr;
        self.data[idx]
    }
}
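
// Layout example: for a core with (rank_left, mode_size, rank_right) = (2, 3, 2),
// element (rl, i, rr) lives at flat index rl * 6 + i * 2 + rr,
// so (1, 2, 0) maps to index 10.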

/// Tensor Train representation
#[derive(Debug, Clone)]
pub struct TensorTrain {
    /// TT-cores
    pub cores: Vec<TTCore>,
    /// Original tensor shape
    pub shape: Vec<usize>,
    /// TT-ranks: [1, r1, r2, ..., r_{d-1}, 1]
    pub ranks: Vec<usize>,
}

impl TensorTrain {
    /// Create TT from cores
    pub fn from_cores(cores: Vec<TTCore>) -> Self {
        let shape: Vec<usize> = cores.iter().map(|c| c.mode_size).collect();
        let mut ranks = vec![1];
        for core in &cores {
            ranks.push(core.rank_right);
        }

        Self { cores, shape, ranks }
    }

    /// Create rank-1 TT from vectors
    pub fn from_vectors(vectors: Vec<Vec<f64>>) -> Self {
        let cores: Vec<TTCore> = vectors
            .into_iter()
            .map(|v| {
                let n = v.len();
                TTCore::new(v, 1, n, 1)
            })
            .collect();

        Self::from_cores(cores)
    }

    /// Tensor order
    pub fn order(&self) -> usize {
        self.shape.len()
    }

    /// Maximum TT-rank
    pub fn max_rank(&self) -> usize {
        self.ranks.iter().cloned().max().unwrap_or(1)
    }

    /// Total storage
    pub fn storage(&self) -> usize {
        self.cores.iter().map(|c| c.data.len()).sum()
    }

    /// Evaluate the TT at a multi-index by multiplying the chain of
    /// core slices left to right
    pub fn eval(&self, indices: &[usize]) -> f64 {
        assert_eq!(indices.len(), self.order());

        // Start with a 1×1 "matrix"
        let mut result = vec![1.0];
        let mut current_size = 1;

        for (k, &idx) in indices.iter().enumerate() {
            let core = &self.cores[k];
            let new_size = core.rank_right;
            let mut new_result = vec![0.0; new_size];

            // Row-vector × matrix product with the k-th core slice
            for rr in 0..new_size {
                for rl in 0..current_size {
                    new_result[rr] += result[rl] * core.get(rl, idx, rr);
                }
            }

            result = new_result;
            current_size = new_size;
        }

        result[0]
    }

    /// Convert to dense tensor
    pub fn to_dense(&self) -> DenseTensor {
        let total_size: usize = self.shape.iter().product();
        let mut data = vec![0.0; total_size];

        // Enumerate all indices in row-major order (last index fastest)
        let mut indices = vec![0usize; self.order()];
        for flat_idx in 0..total_size {
            data[flat_idx] = self.eval(&indices);

            // Increment indices
            for k in (0..self.order()).rev() {
                indices[k] += 1;
                if indices[k] < self.shape[k] {
                    break;
                }
                indices[k] = 0;
            }
        }

        DenseTensor::new(data, self.shape.clone())
    }

    /// Dot product of two TTs
    pub fn dot(&self, other: &TensorTrain) -> f64 {
        assert_eq!(self.shape, other.shape);

        // Accumulate product of contracted cores
        // Result shape at step k: (r1_k × r2_k)
        let mut z = vec![1.0]; // Start with 1×1
        let mut z_rows = 1;
        let mut z_cols = 1;

        for k in 0..self.order() {
            let c1 = &self.cores[k];
            let c2 = &other.cores[k];
            let n = c1.mode_size;

            let new_rows = c1.rank_right;
            let new_cols = c2.rank_right;
            let mut new_z = vec![0.0; new_rows * new_cols];

            // Contract over mode index and previous ranks
            for i in 0..n {
                for r1l in 0..z_rows {
                    for r2l in 0..z_cols {
                        let z_val = z[r1l * z_cols + r2l];

                        for r1r in 0..c1.rank_right {
                            for r2r in 0..c2.rank_right {
                                new_z[r1r * new_cols + r2r] +=
                                    z_val * c1.get(r1l, i, r1r) * c2.get(r2l, i, r2r);
                            }
                        }
                    }
                }
            }

            z = new_z;
            z_rows = new_rows;
            z_cols = new_cols;
        }

        z[0]
    }
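
    // Note: the five-nested contraction above costs O(n * r⁴) per core.
    // Splitting it into two steps (contract z with c1 over its left rank
    // first, then the result with c2) gives the standard O(n * r³) TT
    // inner product.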

    /// Frobenius norm: ||A||_F = sqrt(<A, A>)
    pub fn frobenius_norm(&self) -> f64 {
        self.dot(self).sqrt()
    }

    /// Add two TTs (result has rank r1 + r2)
    pub fn add(&self, other: &TensorTrain) -> TensorTrain {
        assert_eq!(self.shape, other.shape);

        // Order-1 tensors are plain vectors (both boundary ranks are 1),
        // so just add the single cores elementwise.
        if self.order() == 1 {
            let mut core = self.cores[0].clone();
            for (x, y) in core.data.iter_mut().zip(other.cores[0].data.iter()) {
                *x += y;
            }
            return TensorTrain::from_cores(vec![core]);
        }

        let mut new_cores = Vec::new();

        for k in 0..self.order() {
            let c1 = &self.cores[k];
            let c2 = &other.cores[k];

            let new_rl = if k == 0 { 1 } else { c1.rank_left + c2.rank_left };
            let new_rr = if k == self.order() - 1 { 1 } else { c1.rank_right + c2.rank_right };
            let n = c1.mode_size;

            let mut new_core = TTCore::zeros(new_rl, n, new_rr);

            for i in 0..n {
                if k == 0 {
                    // First core: [c1, c2] horizontally
                    for rr1 in 0..c1.rank_right {
                        new_core.set(0, i, rr1, c1.get(0, i, rr1));
                    }
                    for rr2 in 0..c2.rank_right {
                        new_core.set(0, i, c1.rank_right + rr2, c2.get(0, i, rr2));
                    }
                } else if k == self.order() - 1 {
                    // Last core: [c1; c2] vertically
                    for rl1 in 0..c1.rank_left {
                        new_core.set(rl1, i, 0, c1.get(rl1, i, 0));
                    }
                    for rl2 in 0..c2.rank_left {
                        new_core.set(c1.rank_left + rl2, i, 0, c2.get(rl2, i, 0));
                    }
                } else {
                    // Middle core: block diagonal
                    for rl1 in 0..c1.rank_left {
                        for rr1 in 0..c1.rank_right {
                            new_core.set(rl1, i, rr1, c1.get(rl1, i, rr1));
                        }
                    }
                    for rl2 in 0..c2.rank_left {
                        for rr2 in 0..c2.rank_right {
                            new_core.set(c1.rank_left + rl2, i, c1.rank_right + rr2, c2.get(rl2, i, rr2));
                        }
                    }
                }
            }

            new_cores.push(new_core);
        }

        TensorTrain::from_cores(new_cores)
    }
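
    // Block structure per mode index i (used above):
    //   first core:  [ C1[i]  C2[i] ]                 1 × (r1 + r2)
    //   middle core: [ C1[i]    0   ]
    //                [   0    C2[i] ]       (l1 + l2) × (r1 + r2)
    //   last core:   [ C1[i] ]
    //                [ C2[i] ]              (l1 + l2) × 1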

    /// Scale by a constant
    pub fn scale(&self, alpha: f64) -> TensorTrain {
        let mut new_cores = self.cores.clone();

        // Scale the first core only; the chain product picks up the factor once
        if let Some(first) = new_cores.first_mut() {
            for val in first.data.iter_mut() {
                *val *= alpha;
            }
        }

        TensorTrain::from_cores(new_cores)
    }

    /// TT-SVD decomposition from dense tensor
    pub fn from_dense(tensor: &DenseTensor, config: &TensorTrainConfig) -> Self {
        let d = tensor.order();
        if d == 0 {
            return TensorTrain::from_cores(vec![]);
        }

        let mut cores = Vec::new();
        let mut c = tensor.data.clone();
        let mut remaining_shape = tensor.shape.clone();
        let mut left_rank = 1usize;

        for _ in 0..d - 1 {
            let n_k = remaining_shape[0];
            let rest_size: usize = remaining_shape[1..].iter().product();

            // Reshape C to (left_rank * n_k) × rest_size
            let rows = left_rank * n_k;
            let cols = rest_size;

            // Simple SVD via power iteration (for demonstration)
            let (u, s, vt, new_rank) = simple_svd(&c, rows, cols, config);

            // Create core from U (row-major, rows × new_rank)
            let core = TTCore::new(u, left_rank, n_k, new_rank);
            cores.push(core);

            // C = S * Vt for next iteration
            c = Vec::with_capacity(new_rank * cols);
            for i in 0..new_rank {
                for j in 0..cols {
                    c.push(s[i] * vt[i * cols + j]);
                }
            }

            left_rank = new_rank;
            remaining_shape.remove(0);
        }

        // Last core
        let n_d = remaining_shape[0];
        let last_core = TTCore::new(c, left_rank, n_d, 1);
        cores.push(last_core);

        TensorTrain::from_cores(cores)
    }
}
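
// TT-SVD trace for a 2 × 3 × 4 tensor (ranks unbounded):
//   step 1: reshape C to (1·2) × 12, SVD → U is 2 × r1, C ← S·Vt is r1 × 12
//   step 2: reshape C to (r1·3) × 4, SVD → U is 3r1 × r2, C ← S·Vt is r2 × 4
//   finish: C becomes the last core with shape (r2, 4, 1)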

/// Simple truncated SVD using power iteration and deflation.
/// Returns (U, S, Vt, rank) with U in row-major (rows × rank) order
/// and Vt in row-major (rank × cols) order.
fn simple_svd(a: &[f64], rows: usize, cols: usize, config: &TensorTrainConfig) -> (Vec<f64>, Vec<f64>, Vec<f64>, usize) {
    let max_rank = if config.max_rank > 0 {
        config.max_rank.min(rows).min(cols)
    } else {
        rows.min(cols)
    };

    // Collect singular vectors contiguously: u_cols is rank × rows, vt is rank × cols
    let mut u_cols = Vec::new();
    let mut s = Vec::new();
    let mut vt = Vec::new();

    let mut a_residual = a.to_vec();

    for _ in 0..max_rank {
        // Power iteration to find the top singular triplet of the residual
        let (sigma, u_vec, v_vec) = power_iteration(&a_residual, rows, cols, 20);

        if sigma < config.tolerance {
            break;
        }

        s.push(sigma);
        u_cols.extend_from_slice(&u_vec);
        vt.extend_from_slice(&v_vec);

        // Deflate: A = A - sigma * u * v^T
        for i in 0..rows {
            for j in 0..cols {
                a_residual[i * cols + j] -= sigma * u_vec[i] * v_vec[j];
            }
        }
    }

    // Numerically zero matrix: fall back to a rank-1 factorization of zeros
    if s.is_empty() {
        s.push(0.0);
        u_cols = vec![0.0; rows];
        vt = vec![0.0; cols];
    }

    let rank = s.len();

    // Transpose U from (rank × rows) to row-major (rows × rank),
    // matching the TT-core layout expected by the caller
    let mut u = vec![0.0; rows * rank];
    for r in 0..rank {
        for i in 0..rows {
            u[i * rank + r] = u_cols[r * rows + i];
        }
    }

    (u, s, vt, rank)
}

/// Power iteration for the largest singular value
fn power_iteration(a: &[f64], rows: usize, cols: usize, max_iter: usize) -> (f64, Vec<f64>, Vec<f64>) {
    // Deterministic pseudo-random initial v (Knuth multiplicative hash),
    // mapped into [-1, 1)
    let mut v: Vec<f64> = (0..cols)
        .map(|i| {
            let h = (i as u64).wrapping_mul(2654435769) & 0xFFFF_FFFF;
            (h as f64 / 4294967296.0) * 2.0 - 1.0
        })
        .collect();
    normalize(&mut v);

    let mut u = vec![0.0; rows];

    for _ in 0..max_iter {
        // u = A * v
        for i in 0..rows {
            u[i] = 0.0;
            for j in 0..cols {
                u[i] += a[i * cols + j] * v[j];
            }
        }
        normalize(&mut u);

        // v = A^T * u
        for j in 0..cols {
            v[j] = 0.0;
            for i in 0..rows {
                v[j] += a[i * cols + j] * u[i];
            }
        }
        normalize(&mut v);
    }

    // Compute singular value sigma = u^T A v
    let mut av = vec![0.0; rows];
    for i in 0..rows {
        for j in 0..cols {
            av[i] += a[i * cols + j] * v[j];
        }
    }
    let sigma: f64 = u.iter().zip(av.iter()).map(|(ui, avi)| ui * avi).sum();

    (sigma.abs(), u, v)
}

fn normalize(v: &mut [f64]) {
    let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
    if norm > 1e-15 {
        for x in v.iter_mut() {
            *x /= norm;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tt_eval() {
        // Rank-1 TT representing outer product of [1,2] and [3,4]
        let v1 = vec![1.0, 2.0];
        let v2 = vec![3.0, 4.0];
        let tt = TensorTrain::from_vectors(vec![v1, v2]);

        // Should equal v1[i] * v2[j]
        assert!((tt.eval(&[0, 0]) - 3.0).abs() < 1e-10);
        assert!((tt.eval(&[0, 1]) - 4.0).abs() < 1e-10);
        assert!((tt.eval(&[1, 0]) - 6.0).abs() < 1e-10);
        assert!((tt.eval(&[1, 1]) - 8.0).abs() < 1e-10);
    }

    #[test]
    fn test_tt_dot() {
        let v1 = vec![1.0, 2.0];
        let v2 = vec![3.0, 4.0];
        let tt = TensorTrain::from_vectors(vec![v1, v2]);

        // <A, A> = sum of squares
        let norm_sq = tt.dot(&tt);
        // Elements: 3, 4, 6, 8 -> sum of squares = 9 + 16 + 36 + 64 = 125
        assert!((norm_sq - 125.0).abs() < 1e-10);
    }

    #[test]
    fn test_tt_from_dense() {
        let tensor = DenseTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let tt = TensorTrain::from_dense(&tensor, &TensorTrainConfig::default());

        // Check reconstruction
        let reconstructed = tt.to_dense();
        let error: f64 = tensor.data.iter().zip(reconstructed.data.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f64>()
            .sqrt();

        assert!(error < 1e-6);
    }

    #[test]
    fn test_tt_add() {
        let v1 = vec![1.0, 2.0];
        let v2 = vec![3.0, 4.0];
        let tt1 = TensorTrain::from_vectors(vec![v1.clone(), v2.clone()]);
        let tt2 = TensorTrain::from_vectors(vec![v1, v2]);

        let sum = tt1.add(&tt2);

        // Should be 2 * tt1
        assert!((sum.eval(&[0, 0]) - 6.0).abs() < 1e-10);
        assert!((sum.eval(&[1, 1]) - 16.0).abs() < 1e-10);
    }
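
    // Further sketches exercising scale, frobenius_norm, and rank
    // bookkeeping on the same rank-1 example used above.

    #[test]
    fn test_tt_scale_and_norm() {
        let tt = TensorTrain::from_vectors(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);

        // ||A||_F = sqrt(125), matching test_tt_dot
        assert!((tt.frobenius_norm() - 125.0_f64.sqrt()).abs() < 1e-10);

        // Scaling multiplies every entry: 2 * (2 * 4) = 16
        let scaled = tt.scale(2.0);
        assert!((scaled.eval(&[1, 1]) - 16.0).abs() < 1e-10);
    }

    #[test]
    fn test_tt_add_rank_growth() {
        let tt = TensorTrain::from_vectors(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        let sum = tt.add(&tt);

        // Adding two rank-1 trains doubles the internal rank: [1, 2, 1]
        assert_eq!(sum.ranks, vec![1, 2, 1]);
        assert_eq!(sum.max_rank(), 2);
    }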
}