// ruvector_math/tensor_networks/tensor_train.rs

//! Tensor Train (TT) Decomposition
//!
//! The Tensor Train format represents a d-dimensional tensor as:
//!
//! A[i1, i2, ..., id] = G1[i1] × G2[i2] × ... × Gd[id]
//!
//! where each Gk[ik] is an (r_{k-1} × r_k) matrix, called a TT-core.
//! The boundary ranks are fixed to r0 = rd = 1, so the matrix product
//! collapses to a scalar.
//!
//! ## Complexity
//!
//! - Storage: O(d * n * r²) instead of O(n^d)
//! - Single-entry evaluation: O(d * r²)
//! - Dot product: O(d * n * r³) with blocked contractions (the
//!   straightforward loops used here cost O(d * n * r⁴))
//! - Addition: O(d * n * r²), with ranks adding (r ← r1 + r2)
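//!
//! ## Example
//!
//! A minimal rank-1 usage sketch (the `use` path below is assumed from this
//! file's location in the crate, so the snippet is marked `ignore`):
//!
//! ```rust,ignore
//! use ruvector_math::tensor_networks::tensor_train::TensorTrain;
//!
//! // A[i, j] = v1[i] * v2[j]: the outer product [1, 2] ⊗ [3, 4]
//! let tt = TensorTrain::from_vectors(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
//! assert!((tt.eval(&[1, 0]) - 6.0).abs() < 1e-10); // 2 * 3
//! assert_eq!(tt.max_rank(), 1); // all TT-ranks are 1
//! ```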

use super::DenseTensor;

/// Tensor Train configuration
#[derive(Debug, Clone)]
pub struct TensorTrainConfig {
    /// Maximum rank (0 = no limit)
    pub max_rank: usize,
    /// Truncation tolerance
    pub tolerance: f64,
}

impl Default for TensorTrainConfig {
    fn default() -> Self {
        Self {
            max_rank: 0,
            tolerance: 1e-12,
        }
    }
}

/// A single TT-core: 3D tensor of shape (rank_left, mode_size, rank_right)
#[derive(Debug, Clone)]
pub struct TTCore {
    /// Core data in row-major order: [rank_left, mode_size, rank_right]
    pub data: Vec<f64>,
    /// Left rank
    pub rank_left: usize,
    /// Mode size
    pub mode_size: usize,
    /// Right rank
    pub rank_right: usize,
}

impl TTCore {
    /// Create new TT-core
    pub fn new(data: Vec<f64>, rank_left: usize, mode_size: usize, rank_right: usize) -> Self {
        assert_eq!(data.len(), rank_left * mode_size * rank_right);
        Self {
            data,
            rank_left,
            mode_size,
            rank_right,
        }
    }

    /// Create zeros core
    pub fn zeros(rank_left: usize, mode_size: usize, rank_right: usize) -> Self {
        Self {
            data: vec![0.0; rank_left * mode_size * rank_right],
            rank_left,
            mode_size,
            rank_right,
        }
    }
    /// Get the (rank_left × rank_right) matrix slice for mode index i
    pub fn get_matrix(&self, i: usize) -> Vec<f64> {
        // Extract the i-th mode slice from the [rank_left, mode_size, rank_right]
        // layout as a row-major (rank_left × rank_right) matrix.
        let mut result = vec![0.0; self.rank_left * self.rank_right];
        for rl in 0..self.rank_left {
            for rr in 0..self.rank_right {
                let idx = rl * self.mode_size * self.rank_right + i * self.rank_right + rr;
                result[rl * self.rank_right + rr] = self.data[idx];
            }
        }
        result
    }

    /// Set element at (rank_left, mode, rank_right) position
    pub fn set(&mut self, rl: usize, i: usize, rr: usize, value: f64) {
        let idx = rl * self.mode_size * self.rank_right + i * self.rank_right + rr;
        self.data[idx] = value;
    }

    /// Get element at (rank_left, mode, rank_right) position
    pub fn get(&self, rl: usize, i: usize, rr: usize) -> f64 {
        let idx = rl * self.mode_size * self.rank_right + i * self.rank_right + rr;
        self.data[idx]
    }
}

/// Tensor Train representation
#[derive(Debug, Clone)]
pub struct TensorTrain {
    /// TT-cores
    pub cores: Vec<TTCore>,
    /// Original tensor shape
    pub shape: Vec<usize>,
    /// TT-ranks: [1, r1, r2, ..., r_{d-1}, 1]
    pub ranks: Vec<usize>,
}

impl TensorTrain {
    /// Create TT from cores
    pub fn from_cores(cores: Vec<TTCore>) -> Self {
        let shape: Vec<usize> = cores.iter().map(|c| c.mode_size).collect();
        let mut ranks = vec![1];
        for core in &cores {
            ranks.push(core.rank_right);
        }

        Self {
            cores,
            shape,
            ranks,
        }
    }

    /// Create rank-1 TT from vectors
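    ///
    /// The resulting train represents the separable (outer-product) tensor
    /// A[i1, ..., id] = v1[i1] * v2[i2] * ... * vd[id], with all TT-ranks
    /// equal to 1.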
    pub fn from_vectors(vectors: Vec<Vec<f64>>) -> Self {
        let cores: Vec<TTCore> = vectors
            .into_iter()
            .map(|v| {
                let n = v.len();
                TTCore::new(v, 1, n, 1)
            })
            .collect();

        Self::from_cores(cores)
    }

    /// Tensor order
    pub fn order(&self) -> usize {
        self.shape.len()
    }

    /// Maximum TT-rank
    pub fn max_rank(&self) -> usize {
        self.ranks.iter().cloned().max().unwrap_or(1)
    }

    /// Total storage
    pub fn storage(&self) -> usize {
        self.cores.iter().map(|c| c.data.len()).sum()
    }

    /// Evaluate TT at a multi-index
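    ///
    /// Sweeps left to right, maintaining a row vector that is multiplied by
    /// each core's matrix slice: v ← v * Gk[ik]. Since r0 = rd = 1, the
    /// final vector has length 1. Cost: O(d * r²).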
    pub fn eval(&self, indices: &[usize]) -> f64 {
        assert_eq!(indices.len(), self.order());

        // Start with the 1×1 "matrix" [1.0]
        let mut result = vec![1.0];
        let mut current_size = 1;

        for (k, &idx) in indices.iter().enumerate() {
            let core = &self.cores[k];
            let new_size = core.rank_right;
            let mut new_result = vec![0.0; new_size];

            // Row-vector times matrix: new_result = result * Gk[idx]
            for rr in 0..new_size {
                for rl in 0..current_size {
                    new_result[rr] += result[rl] * core.get(rl, idx, rr);
                }
            }

            result = new_result;
            current_size = new_size;
        }

        result[0]
    }

    /// Convert to dense tensor
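    ///
    /// Note: this materializes all prod(shape) entries, so it is only
    /// practical for small tensors (cost O(n^d * d * r²) with this
    /// entry-by-entry evaluation).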
    pub fn to_dense(&self) -> DenseTensor {
        let total_size: usize = self.shape.iter().product();
        let mut data = vec![0.0; total_size];

        // Enumerate all multi-indices in row-major order
        let mut indices = vec![0usize; self.order()];
        for flat_idx in 0..total_size {
            data[flat_idx] = self.eval(&indices);

            // Increment the multi-index (last mode varies fastest)
            for k in (0..self.order()).rev() {
                indices[k] += 1;
                if indices[k] < self.shape[k] {
                    break;
                }
                indices[k] = 0;
            }
        }

        DenseTensor::new(data, self.shape.clone())
    }

    /// Dot product of two TTs
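    ///
    /// Sweeps over the modes while maintaining a (r1_k × r2_k) matrix Z with
    /// the recurrence Z ← Σ_i G1[i]ᵀ Z G2[i]; after the last mode, Z is 1×1
    /// and holds <A, B>.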
    pub fn dot(&self, other: &TensorTrain) -> f64 {
        assert_eq!(self.shape, other.shape);

        // Accumulated contraction of the first k cores of both trains.
        // Shape at step k: (r1_k × r2_k)
        let mut z = vec![1.0]; // Start with 1×1
        let mut z_rows = 1;
        let mut z_cols = 1;

        for k in 0..self.order() {
            let c1 = &self.cores[k];
            let c2 = &other.cores[k];
            let n = c1.mode_size;

            let new_rows = c1.rank_right;
            let new_cols = c2.rank_right;
            let mut new_z = vec![0.0; new_rows * new_cols];

            // Contract over the mode index and the previous ranks
            for i in 0..n {
                for r1l in 0..z_rows {
                    for r2l in 0..z_cols {
                        let z_val = z[r1l * z_cols + r2l];

                        for r1r in 0..c1.rank_right {
                            for r2r in 0..c2.rank_right {
                                new_z[r1r * new_cols + r2r] +=
                                    z_val * c1.get(r1l, i, r1r) * c2.get(r2l, i, r2r);
                            }
                        }
                    }
                }
            }

            z = new_z;
            z_rows = new_rows;
            z_cols = new_cols;
        }

        z[0]
    }

    /// Frobenius norm: ||A||_F = sqrt(<A, A>)
    pub fn frobenius_norm(&self) -> f64 {
        self.dot(self).sqrt()
    }

    /// Add two TTs (result has ranks r1 + r2)
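    ///
    /// Uses the standard block construction: the first cores are stacked
    /// horizontally as [G1 H1], the last vertically as [Gd; Hd], and the
    /// middle cores block-diagonally as diag(Gk, Hk).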
    pub fn add(&self, other: &TensorTrain) -> TensorTrain {
        assert_eq!(self.shape, other.shape);

        // Order-1 trains are plain vectors with unit ranks on both sides,
        // so the block construction below does not apply: add element-wise.
        if self.order() == 1 {
            let mut data = self.cores[0].data.clone();
            for (x, y) in data.iter_mut().zip(other.cores[0].data.iter()) {
                *x += y;
            }
            let n = self.shape[0];
            return TensorTrain::from_cores(vec![TTCore::new(data, 1, n, 1)]);
        }

        let mut new_cores = Vec::new();

        for k in 0..self.order() {
            let c1 = &self.cores[k];
            let c2 = &other.cores[k];

            let new_rl = if k == 0 {
                1
            } else {
                c1.rank_left + c2.rank_left
            };
            let new_rr = if k == self.order() - 1 {
                1
            } else {
                c1.rank_right + c2.rank_right
            };
            let n = c1.mode_size;

            let mut new_core = TTCore::zeros(new_rl, n, new_rr);

            for i in 0..n {
                if k == 0 {
                    // First core: [c1, c2] horizontally
                    for rr1 in 0..c1.rank_right {
                        new_core.set(0, i, rr1, c1.get(0, i, rr1));
                    }
                    for rr2 in 0..c2.rank_right {
                        new_core.set(0, i, c1.rank_right + rr2, c2.get(0, i, rr2));
                    }
                } else if k == self.order() - 1 {
                    // Last core: [c1; c2] vertically
                    for rl1 in 0..c1.rank_left {
                        new_core.set(rl1, i, 0, c1.get(rl1, i, 0));
                    }
                    for rl2 in 0..c2.rank_left {
                        new_core.set(c1.rank_left + rl2, i, 0, c2.get(rl2, i, 0));
                    }
                } else {
                    // Middle core: block diagonal
                    for rl1 in 0..c1.rank_left {
                        for rr1 in 0..c1.rank_right {
                            new_core.set(rl1, i, rr1, c1.get(rl1, i, rr1));
                        }
                    }
                    for rl2 in 0..c2.rank_left {
                        for rr2 in 0..c2.rank_right {
                            new_core.set(
                                c1.rank_left + rl2,
                                i,
                                c1.rank_right + rr2,
                                c2.get(rl2, i, rr2),
                            );
                        }
                    }
                }
            }

            new_cores.push(new_core);
        }

        TensorTrain::from_cores(new_cores)
    }

    /// Scale by a constant
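    ///
    /// The TT representation is multilinear in its cores, so scaling any
    /// single core scales the whole tensor; the first core is chosen here.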
    pub fn scale(&self, alpha: f64) -> TensorTrain {
        let mut new_cores = self.cores.clone();

        // Scale the first core only
        for val in new_cores[0].data.iter_mut() {
            *val *= alpha;
        }

        TensorTrain::from_cores(new_cores)
    }

    /// TT-SVD decomposition from dense tensor
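    ///
    /// Classic TT-SVD sweep: at step k, reshape the carry matrix C into
    /// (r_{k-1} * n_k) × (n_{k+1} * ... * n_d), take a truncated SVD
    /// C ≈ U S Vᵀ, absorb U into the k-th core, and carry C ← S Vᵀ to the
    /// next step.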
    pub fn from_dense(tensor: &DenseTensor, config: &TensorTrainConfig) -> Self {
        let d = tensor.order();
        if d == 0 {
            return TensorTrain::from_cores(vec![]);
        }

        let mut cores = Vec::new();
        let mut c = tensor.data.clone();
        let mut remaining_shape = tensor.shape.clone();
        let mut left_rank = 1usize;

        for _k in 0..d - 1 {
            let n_k = remaining_shape[0];
            let rest_size: usize = remaining_shape[1..].iter().product();

            // Reshape C to (left_rank * n_k) × rest_size
            let rows = left_rank * n_k;
            let cols = rest_size;

            // Simple truncated SVD via power iteration (for demonstration)
            let (ut, s, vt, new_rank) = simple_svd(&c, rows, cols, config);

            // simple_svd returns Uᵀ (one left singular vector per row);
            // transpose into row-major (rows × new_rank) so it matches the
            // TT-core layout [rank_left, mode_size, rank_right].
            let mut u = vec![0.0; rows * new_rank];
            for r in 0..new_rank {
                for row in 0..rows {
                    u[row * new_rank + r] = ut[r * rows + row];
                }
            }
            cores.push(TTCore::new(u, left_rank, n_k, new_rank));

            // C = S * Vt for the next iteration
            c = Vec::with_capacity(new_rank * cols);
            for i in 0..new_rank {
                for j in 0..cols {
                    c.push(s[i] * vt[i * cols + j]);
                }
            }

            left_rank = new_rank;
            remaining_shape.remove(0);
        }

        // Last core carries the remaining factor with right rank 1
        let n_d = remaining_shape[0];
        let last_core = TTCore::new(c, left_rank, n_d, 1);
        cores.push(last_core);

        TensorTrain::from_cores(cores)
    }
}

/// Simple truncated SVD using deflated power iteration.
/// Returns (Ut, S, Vt, rank), where Ut stores one left singular vector per
/// row (rank × rows) and Vt one right singular vector per row (rank × cols).
fn simple_svd(
    a: &[f64],
    rows: usize,
    cols: usize,
    config: &TensorTrainConfig,
) -> (Vec<f64>, Vec<f64>, Vec<f64>, usize) {
    let max_rank = if config.max_rank > 0 {
        config.max_rank.min(rows).min(cols)
    } else {
        rows.min(cols)
    };

    let mut ut = Vec::new();
    let mut s = Vec::new();
    let mut vt = Vec::new();

    let mut a_residual = a.to_vec();

    for _ in 0..max_rank {
        // Power iteration to find the top singular triple of the residual
        let (sigma, u_vec, v_vec) = power_iteration(&a_residual, rows, cols, 20);

        if sigma < config.tolerance {
            break;
        }

        s.push(sigma);
        ut.extend_from_slice(&u_vec);
        vt.extend_from_slice(&v_vec);

        // Deflate: A = A - sigma * u * v^T
        for i in 0..rows {
            for j in 0..cols {
                a_residual[i * cols + j] -= sigma * u_vec[i] * v_vec[j];
            }
        }
    }

    // Guard: for an (effectively) zero matrix, return a rank-1 zero
    // factorization so callers always get consistently sized outputs.
    if s.is_empty() {
        return (vec![0.0; rows], vec![0.0], vec![0.0; cols], 1);
    }

    let rank = s.len();
    (ut, s, vt, rank)
}

/// Power iteration for the dominant singular triple (σ, u, v)
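///
/// Alternates u ← A v / ||A v|| and v ← Aᵀ u / ||Aᵀ u||, which converges to
/// the dominant left/right singular vectors; σ is then recovered as uᵀ A v.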
fn power_iteration(
    a: &[f64],
    rows: usize,
    cols: usize,
    max_iter: usize,
) -> (f64, Vec<f64>, Vec<f64>) {
    // Deterministic pseudo-random init via a Knuth-style multiplicative
    // hash, reduced mod 2^32 so every entry lands in [-1, 1)
    let mut v: Vec<f64> = (0..cols)
        .map(|i| {
            let h = (i.wrapping_mul(2654435769) & 0xFFFF_FFFF) as f64;
            (h / 4294967296.0) * 2.0 - 1.0
        })
        .collect();
    normalize(&mut v);

    let mut u = vec![0.0; rows];

    for _ in 0..max_iter {
        // u = A * v
        for i in 0..rows {
            u[i] = 0.0;
            for j in 0..cols {
                u[i] += a[i * cols + j] * v[j];
            }
        }
        normalize(&mut u);

        // v = A^T * u
        for j in 0..cols {
            v[j] = 0.0;
            for i in 0..rows {
                v[j] += a[i * cols + j] * u[i];
            }
        }
        normalize(&mut v);
    }

    // Singular value: sigma = u^T * A * v
    let mut av = vec![0.0; rows];
    for i in 0..rows {
        for j in 0..cols {
            av[i] += a[i * cols + j] * v[j];
        }
    }
    let sigma: f64 = u.iter().zip(av.iter()).map(|(ui, avi)| ui * avi).sum();

    (sigma.abs(), u, v)
}

fn normalize(v: &mut [f64]) {
    let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
    if norm > 1e-15 {
        for x in v.iter_mut() {
            *x /= norm;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tt_eval() {
        // Rank-1 TT representing the outer product of [1,2] and [3,4]
        let v1 = vec![1.0, 2.0];
        let v2 = vec![3.0, 4.0];
        let tt = TensorTrain::from_vectors(vec![v1, v2]);

        // Should equal v1[i] * v2[j]
        assert!((tt.eval(&[0, 0]) - 3.0).abs() < 1e-10);
        assert!((tt.eval(&[0, 1]) - 4.0).abs() < 1e-10);
        assert!((tt.eval(&[1, 0]) - 6.0).abs() < 1e-10);
        assert!((tt.eval(&[1, 1]) - 8.0).abs() < 1e-10);
    }

    #[test]
    fn test_tt_dot() {
        let v1 = vec![1.0, 2.0];
        let v2 = vec![3.0, 4.0];
        let tt = TensorTrain::from_vectors(vec![v1, v2]);

        // <A, A> = sum of squares
        let norm_sq = tt.dot(&tt);
        // Elements: 3, 4, 6, 8 -> sum of squares = 9 + 16 + 36 + 64 = 125
        assert!((norm_sq - 125.0).abs() < 1e-10);
    }

    #[test]
    fn test_tt_from_dense() {
        let tensor = DenseTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let tt = TensorTrain::from_dense(&tensor, &TensorTrainConfig::default());

        // Check reconstruction
        let reconstructed = tt.to_dense();
        let error: f64 = tensor
            .data
            .iter()
            .zip(reconstructed.data.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f64>()
            .sqrt();

        assert!(error < 1e-6);
    }

    #[test]
    fn test_tt_add() {
        let v1 = vec![1.0, 2.0];
        let v2 = vec![3.0, 4.0];
        let tt1 = TensorTrain::from_vectors(vec![v1.clone(), v2.clone()]);
        let tt2 = TensorTrain::from_vectors(vec![v1, v2]);

        let sum = tt1.add(&tt2);

        // Should be 2 * tt1
        assert!((sum.eval(&[0, 0]) - 6.0).abs() < 1e-10);
        assert!((sum.eval(&[1, 1]) - 16.0).abs() < 1e-10);
    }
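
    // Added coverage (not in the original tests): exercises scale() and
    // frobenius_norm() on the same rank-1 train used above.
    #[test]
    fn test_tt_scale_and_norm() {
        let tt = TensorTrain::from_vectors(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);

        // ||A||_F = sqrt(9 + 16 + 36 + 64) = sqrt(125)
        assert!((tt.frobenius_norm() - 125.0_f64.sqrt()).abs() < 1e-10);

        // Scaling multiplies every entry by alpha
        let scaled = tt.scale(2.0);
        assert!((scaled.eval(&[0, 0]) - 6.0).abs() < 1e-10);
        assert!((scaled.eval(&[1, 1]) - 16.0).abs() < 1e-10);
    }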
}