// ruvector_math/tensor_networks/tensor_train.rs

use super::DenseTensor;
/// Configuration for constructing a tensor-train decomposition.
#[derive(Debug, Clone)]
pub struct TensorTrainConfig {
    /// Upper bound on the TT ranks; `0` means no explicit cap
    /// (the rank is then limited only by the matrix dimensions).
    pub max_rank: usize,
    /// Singular values below this threshold are discarded during truncation.
    pub tolerance: f64,
}

impl Default for TensorTrainConfig {
    /// Unlimited rank with a near-machine-precision truncation tolerance.
    fn default() -> Self {
        TensorTrainConfig {
            max_rank: 0,
            tolerance: 1e-12,
        }
    }
}
35
/// A single third-order TT core of shape `(rank_left, mode_size, rank_right)`,
/// stored row-major in `data`: element `(rl, i, rr)` lives at
/// `rl * mode_size * rank_right + i * rank_right + rr`.
#[derive(Debug, Clone)]
pub struct TTCore {
    /// Flattened core entries, length `rank_left * mode_size * rank_right`.
    pub data: Vec<f64>,
    /// Rank shared with the previous core (1 for the first core).
    pub rank_left: usize,
    /// Size of the physical mode this core carries.
    pub mode_size: usize,
    /// Rank shared with the next core (1 for the last core).
    pub rank_right: usize,
}

impl TTCore {
    /// Wraps `data` as a core with the given shape.
    ///
    /// # Panics
    /// Panics if `data.len() != rank_left * mode_size * rank_right`.
    pub fn new(data: Vec<f64>, rank_left: usize, mode_size: usize, rank_right: usize) -> Self {
        assert_eq!(data.len(), rank_left * mode_size * rank_right);
        Self {
            data,
            rank_left,
            mode_size,
            rank_right,
        }
    }

    /// All-zero core of the given shape.
    pub fn zeros(rank_left: usize, mode_size: usize, rank_right: usize) -> Self {
        Self {
            data: vec![0.0; rank_left * mode_size * rank_right],
            rank_left,
            mode_size,
            rank_right,
        }
    }

    /// Returns the `rank_left x rank_right` slice `G[:, i, :]` as a
    /// row-major matrix.
    ///
    /// (Previous version computed unused `start`/`end` offsets and
    /// zero-filled a buffer it then fully overwrote; rows are now copied
    /// directly out of `data`, which is contiguous per left rank.)
    pub fn get_matrix(&self, i: usize) -> Vec<f64> {
        let mut result = Vec::with_capacity(self.rank_left * self.rank_right);
        for rl in 0..self.rank_left {
            let row = (rl * self.mode_size + i) * self.rank_right;
            result.extend_from_slice(&self.data[row..row + self.rank_right]);
        }
        result
    }

    /// Writes element `(rl, i, rr)`.
    pub fn set(&mut self, rl: usize, i: usize, rr: usize, value: f64) {
        let idx = rl * self.mode_size * self.rank_right + i * self.rank_right + rr;
        self.data[idx] = value;
    }

    /// Reads element `(rl, i, rr)`.
    pub fn get(&self, rl: usize, i: usize, rr: usize) -> f64 {
        let idx = rl * self.mode_size * self.rank_right + i * self.rank_right + rr;
        self.data[idx]
    }
}
100
/// A tensor in tensor-train (TT / matrix-product) format: a chain of
/// third-order cores whose sequential contraction reproduces the full tensor.
#[derive(Debug, Clone)]
pub struct TensorTrain {
    /// The TT cores, one per tensor mode, in order.
    pub cores: Vec<TTCore>,
    /// Mode sizes n_1..n_d (mirrors `cores[k].mode_size`).
    pub shape: Vec<usize>,
    /// Boundary ranks r_0..r_d; length is `cores.len() + 1`, with r_0 = 1.
    pub ranks: Vec<usize>,
}
111
112impl TensorTrain {
113 pub fn from_cores(cores: Vec<TTCore>) -> Self {
115 let shape: Vec<usize> = cores.iter().map(|c| c.mode_size).collect();
116 let mut ranks = vec![1];
117 for core in &cores {
118 ranks.push(core.rank_right);
119 }
120
121 Self {
122 cores,
123 shape,
124 ranks,
125 }
126 }
127
128 pub fn from_vectors(vectors: Vec<Vec<f64>>) -> Self {
130 let cores: Vec<TTCore> = vectors
131 .into_iter()
132 .map(|v| {
133 let n = v.len();
134 TTCore::new(v, 1, n, 1)
135 })
136 .collect();
137
138 Self::from_cores(cores)
139 }
140
141 pub fn order(&self) -> usize {
143 self.shape.len()
144 }
145
146 pub fn max_rank(&self) -> usize {
148 self.ranks.iter().cloned().max().unwrap_or(1)
149 }
150
151 pub fn storage(&self) -> usize {
153 self.cores.iter().map(|c| c.data.len()).sum()
154 }
155
156 pub fn eval(&self, indices: &[usize]) -> f64 {
158 assert_eq!(indices.len(), self.order());
159
160 let mut result = vec![1.0];
162 let mut current_size = 1;
163
164 for (k, &idx) in indices.iter().enumerate() {
165 let core = &self.cores[k];
166 let new_size = core.rank_right;
167 let mut new_result = vec![0.0; new_size];
168
169 for rr in 0..new_size {
171 for rl in 0..current_size {
172 new_result[rr] += result[rl] * core.get(rl, idx, rr);
173 }
174 }
175
176 result = new_result;
177 current_size = new_size;
178 }
179
180 result[0]
181 }
182
183 pub fn to_dense(&self) -> DenseTensor {
185 let total_size: usize = self.shape.iter().product();
186 let mut data = vec![0.0; total_size];
187
188 let mut indices = vec![0usize; self.order()];
190 for flat_idx in 0..total_size {
191 data[flat_idx] = self.eval(&indices);
192
193 for k in (0..self.order()).rev() {
195 indices[k] += 1;
196 if indices[k] < self.shape[k] {
197 break;
198 }
199 indices[k] = 0;
200 }
201 }
202
203 DenseTensor::new(data, self.shape.clone())
204 }
205
    /// Inner product `<self, other>` of two trains with identical shapes.
    ///
    /// Sweeps left to right carrying the partial contraction matrix `z`
    /// (`z_rows x z_cols`, row-major): `z[r1, r2]` accumulates the contraction
    /// of all processed cores of `self` (open right rank `r1`) against those
    /// of `other` (open right rank `r2`). Boundary ranks are 1, so the final
    /// `z` is 1x1 and `z[0]` is the scalar result.
    ///
    /// # Panics
    /// Panics if the two trains have different shapes.
    pub fn dot(&self, other: &TensorTrain) -> f64 {
        assert_eq!(self.shape, other.shape);

        // z starts as the 1x1 identity contraction.
        let mut z = vec![1.0]; let mut z_rows = 1;
        let mut z_cols = 1;

        for k in 0..self.order() {
            let c1 = &self.cores[k];
            let c2 = &other.cores[k];
            let n = c1.mode_size;

            // Ranks of the carried matrix after absorbing core k.
            let new_rows = c1.rank_right;
            let new_cols = c2.rank_right;
            let mut new_z = vec![0.0; new_rows * new_cols];

            // new_z[r1r, r2r] = sum_{i, r1l, r2l} z[r1l, r2l]
            //                   * G1[r1l, i, r1r] * G2[r2l, i, r2r]
            for i in 0..n {
                for r1l in 0..z_rows {
                    for r2l in 0..z_cols {
                        let z_val = z[r1l * z_cols + r2l];

                        for r1r in 0..c1.rank_right {
                            for r2r in 0..c2.rank_right {
                                new_z[r1r * new_cols + r2r] +=
                                    z_val * c1.get(r1l, i, r1r) * c2.get(r2l, i, r2r);
                            }
                        }
                    }
                }
            }

            z = new_z;
            z_rows = new_rows;
            z_cols = new_cols;
        }

        z[0]
    }
248
249 pub fn frobenius_norm(&self) -> f64 {
251 self.dot(self).sqrt()
252 }
253
254 pub fn add(&self, other: &TensorTrain) -> TensorTrain {
256 assert_eq!(self.shape, other.shape);
257
258 let mut new_cores = Vec::new();
259
260 for k in 0..self.order() {
261 let c1 = &self.cores[k];
262 let c2 = &other.cores[k];
263
264 let new_rl = if k == 0 {
265 1
266 } else {
267 c1.rank_left + c2.rank_left
268 };
269 let new_rr = if k == self.order() - 1 {
270 1
271 } else {
272 c1.rank_right + c2.rank_right
273 };
274 let n = c1.mode_size;
275
276 let mut new_data = vec![0.0; new_rl * n * new_rr];
277 let mut new_core = TTCore::new(new_data.clone(), new_rl, n, new_rr);
278
279 for i in 0..n {
280 if k == 0 {
281 for rr1 in 0..c1.rank_right {
283 new_core.set(0, i, rr1, c1.get(0, i, rr1));
284 }
285 for rr2 in 0..c2.rank_right {
286 new_core.set(0, i, c1.rank_right + rr2, c2.get(0, i, rr2));
287 }
288 } else if k == self.order() - 1 {
289 for rl1 in 0..c1.rank_left {
291 new_core.set(rl1, i, 0, c1.get(rl1, i, 0));
292 }
293 for rl2 in 0..c2.rank_left {
294 new_core.set(c1.rank_left + rl2, i, 0, c2.get(rl2, i, 0));
295 }
296 } else {
297 for rl1 in 0..c1.rank_left {
299 for rr1 in 0..c1.rank_right {
300 new_core.set(rl1, i, rr1, c1.get(rl1, i, rr1));
301 }
302 }
303 for rl2 in 0..c2.rank_left {
304 for rr2 in 0..c2.rank_right {
305 new_core.set(
306 c1.rank_left + rl2,
307 i,
308 c1.rank_right + rr2,
309 c2.get(rl2, i, rr2),
310 );
311 }
312 }
313 }
314 }
315
316 new_cores.push(new_core);
317 }
318
319 TensorTrain::from_cores(new_cores)
320 }
321
322 pub fn scale(&self, alpha: f64) -> TensorTrain {
324 let mut new_cores = self.cores.clone();
325
326 for val in new_cores[0].data.iter_mut() {
328 *val *= alpha;
329 }
330
331 TensorTrain::from_cores(new_cores)
332 }
333
    /// Decomposes a dense tensor into TT format by a left-to-right sweep of
    /// rank-revealing factorizations (TT-SVD style).
    ///
    /// At each step the remaining data `c` is viewed as a
    /// `(left_rank * n_k) x (product of remaining mode sizes)` matrix and
    /// split by `simple_svd` into `U * (S * Vt)`; `U` becomes core k and
    /// `S * Vt` is carried forward. Ranks are capped by `config.max_rank`
    /// (0 = uncapped) and singular values below `config.tolerance` dropped.
    pub fn from_dense(tensor: &DenseTensor, config: &TensorTrainConfig) -> Self {
        let d = tensor.order();
        if d == 0 {
            // Order-0 input: represent it as an empty train.
            return TensorTrain::from_cores(vec![]);
        }

        let mut cores = Vec::new();
        let mut c = tensor.data.clone();
        let mut remaining_shape = tensor.shape.clone();
        let mut left_rank = 1usize;

        for k in 0..d - 1 {
            let n_k = remaining_shape[0];
            let rest_size: usize = remaining_shape[1..].iter().product();

            // Unfold `c` as rows = (carried rank * current mode) x rest.
            let rows = left_rank * n_k;
            let cols = rest_size;

            let (u, s, vt, new_rank) = simple_svd(&c, rows, cols, config);

            // U (row-major rows x new_rank) is exactly core k's data layout.
            let core = TTCore::new(u, left_rank, n_k, new_rank);
            cores.push(core);

            // Carry S * Vt forward as the next matrix to factor.
            c = Vec::with_capacity(new_rank * cols);
            for i in 0..new_rank {
                for j in 0..cols {
                    c.push(s[i] * vt[i * cols + j]);
                }
            }

            left_rank = new_rank;
            remaining_shape.remove(0);
        }

        // Whatever is left becomes the final core, with trailing rank 1.
        let n_d = remaining_shape[0];
        let last_core = TTCore::new(c, left_rank, n_d, 1);
        cores.push(last_core);

        TensorTrain::from_cores(cores)
    }
380}
381
382fn simple_svd(
385 a: &[f64],
386 rows: usize,
387 cols: usize,
388 config: &TensorTrainConfig,
389) -> (Vec<f64>, Vec<f64>, Vec<f64>, usize) {
390 let max_rank = if config.max_rank > 0 {
391 config.max_rank.min(rows).min(cols)
392 } else {
393 rows.min(cols)
394 };
395
396 let mut u = Vec::new();
397 let mut s = Vec::new();
398 let mut vt = Vec::new();
399
400 let mut a_residual = a.to_vec();
401
402 for _ in 0..max_rank {
403 let (sigma, u_vec, v_vec) = power_iteration(&a_residual, rows, cols, 20);
405
406 if sigma < config.tolerance {
407 break;
408 }
409
410 s.push(sigma);
411 u.extend(u_vec.iter());
412 vt.extend(v_vec.iter());
413
414 for i in 0..rows {
416 for j in 0..cols {
417 a_residual[i * cols + j] -= sigma * u_vec[i] * v_vec[j];
418 }
419 }
420 }
421
422 let rank = s.len();
423 (u, s, vt, rank.max(1))
424}
425
/// Estimates the dominant singular triple (sigma, u, v) of the row-major
/// `rows x cols` matrix `a` by alternating power iteration:
/// u <- normalize(A v), v <- normalize(A^T u), repeated `max_iter` times,
/// finally sigma = u^T (A v) (non-negative by construction of u and v).
fn power_iteration(
    a: &[f64],
    rows: usize,
    cols: usize,
    max_iter: usize,
) -> (f64, Vec<f64>, Vec<f64>) {
    // Deterministic pseudo-random start vector in [-1, 1), derived from a
    // multiplicative hash of the index so runs are reproducible.
    let mut v: Vec<f64> = (0..cols)
        .map(|i| ((i * 2654435769) as f64 / 4294967296.0) * 2.0 - 1.0)
        .collect();
    normalize(&mut v);

    let mut u = vec![0.0; rows];

    for _ in 0..max_iter {
        // u <- A v, then normalize.
        for (i, ui) in u.iter_mut().enumerate() {
            let row = &a[i * cols..(i + 1) * cols];
            *ui = row.iter().zip(v.iter()).map(|(x, y)| x * y).sum();
        }
        normalize(&mut u);

        // v <- A^T u, then normalize.
        for (j, vj) in v.iter_mut().enumerate() {
            *vj = u
                .iter()
                .enumerate()
                .map(|(i, ui)| a[i * cols + j] * ui)
                .sum();
        }
        normalize(&mut v);
    }

    // sigma = u . (A v), computed row by row.
    let sigma: f64 = u
        .iter()
        .enumerate()
        .map(|(i, ui)| {
            let av: f64 = a[i * cols..(i + 1) * cols]
                .iter()
                .zip(v.iter())
                .map(|(x, y)| x * y)
                .sum();
            ui * av
        })
        .sum();

    (sigma.abs(), u, v)
}

/// Scales `v` to unit Euclidean length in place; vectors with norm below
/// 1e-15 are left untouched to avoid division blow-up.
fn normalize(v: &mut [f64]) {
    let norm = v.iter().map(|x| x * x).sum::<f64>().sqrt();
    if norm > 1e-15 {
        v.iter_mut().for_each(|x| *x /= norm);
    }
}
481
#[cfg(test)]
mod tests {
    use super::*;

    // A rank-1 train built from two vectors represents their outer product:
    // T[i, j] = v1[i] * v2[j].
    #[test]
    fn test_tt_eval() {
        let v1 = vec![1.0, 2.0];
        let v2 = vec![3.0, 4.0];
        let tt = TensorTrain::from_vectors(vec![v1, v2]);

        assert!((tt.eval(&[0, 0]) - 3.0).abs() < 1e-10);
        assert!((tt.eval(&[0, 1]) - 4.0).abs() < 1e-10);
        assert!((tt.eval(&[1, 0]) - 6.0).abs() < 1e-10);
        assert!((tt.eval(&[1, 1]) - 8.0).abs() < 1e-10);
    }

    // <T, T> equals the sum of squared entries of the outer product
    // [3, 4, 6, 8]: 9 + 16 + 36 + 64 = 125.
    #[test]
    fn test_tt_dot() {
        let v1 = vec![1.0, 2.0];
        let v2 = vec![3.0, 4.0];
        let tt = TensorTrain::from_vectors(vec![v1, v2]);

        let norm_sq = tt.dot(&tt);
        assert!((norm_sq - 125.0).abs() < 1e-10);
    }

    // Round trip: dense -> TT (full rank, default config) -> dense should
    // reconstruct the 2x3 matrix to high accuracy.
    #[test]
    fn test_tt_from_dense() {
        let tensor = DenseTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let tt = TensorTrain::from_dense(&tensor, &TensorTrainConfig::default());

        let reconstructed = tt.to_dense();
        let error: f64 = tensor
            .data
            .iter()
            .zip(reconstructed.data.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f64>()
            .sqrt();

        assert!(error < 1e-6);
    }

    // TT addition is pointwise: (tt1 + tt2)[i, j] = tt1[i, j] + tt2[i, j].
    #[test]
    fn test_tt_add() {
        let v1 = vec![1.0, 2.0];
        let v2 = vec![3.0, 4.0];
        let tt1 = TensorTrain::from_vectors(vec![v1.clone(), v2.clone()]);
        let tt2 = TensorTrain::from_vectors(vec![v1, v2]);

        let sum = tt1.add(&tt2);

        assert!((sum.eval(&[0, 0]) - 6.0).abs() < 1e-10);
        assert!((sum.eval(&[1, 1]) - 16.0).abs() < 1e-10);
    }
}