ruvector_math/tensor_networks/
tensor_train.rs1use super::DenseTensor;
17
/// Settings that control how aggressively a decomposition is truncated.
#[derive(Debug, Clone)]
pub struct TensorTrainConfig {
    /// Largest rank kept by each truncated factorisation; `0` means the rank
    /// is limited only by the dimensions of the matrix being factorised.
    pub max_rank: usize,
    /// Singular values smaller than this stop the factorisation early.
    pub tolerance: f64,
}

impl Default for TensorTrainConfig {
    /// Returns a configuration with no rank cap and a tolerance of `1e-12`.
    fn default() -> Self {
        TensorTrainConfig {
            tolerance: 1e-12,
            max_rank: 0,
        }
    }
}
35
/// One core of a tensor train: a three-way array of shape
/// `rank_left x mode_size x rank_right` stored flat in row-major order.
#[derive(Debug, Clone)]
pub struct TTCore {
    /// Flat entries; element `(rl, i, rr)` lives at
    /// `rl * mode_size * rank_right + i * rank_right + rr`.
    pub data: Vec<f64>,
    /// Size of the left bond index.
    pub rank_left: usize,
    /// Size of the mode (physical) index.
    pub mode_size: usize,
    /// Size of the right bond index.
    pub rank_right: usize,
}

impl TTCore {
    /// Wraps `data` as a core with the given dimensions.
    ///
    /// # Panics
    ///
    /// Panics if `data.len() != rank_left * mode_size * rank_right`.
    pub fn new(data: Vec<f64>, rank_left: usize, mode_size: usize, rank_right: usize) -> Self {
        assert_eq!(data.len(), rank_left * mode_size * rank_right);
        Self { data, rank_left, mode_size, rank_right }
    }

    /// Creates a core of the given dimensions with every entry set to `0.0`.
    pub fn zeros(rank_left: usize, mode_size: usize, rank_right: usize) -> Self {
        Self {
            data: vec![0.0; rank_left * mode_size * rank_right],
            rank_left,
            mode_size,
            rank_right,
        }
    }

    /// Returns the `rank_left x rank_right` slice at mode index `i` as a
    /// row-major vector.
    ///
    /// # Panics
    ///
    /// Debug builds panic when `i >= mode_size`; release builds panic only if
    /// the resulting range falls outside `data`.
    pub fn get_matrix(&self, i: usize) -> Vec<f64> {
        debug_assert!(
            i < self.mode_size,
            "mode index {} out of range for mode size {}",
            i,
            self.mode_size
        );

        let width = self.rank_right;
        let mut result = Vec::with_capacity(self.rank_left * width);
        for rl in 0..self.rank_left {
            // For fixed (rl, i) the `rank_right` entries are contiguous.
            let start = (rl * self.mode_size + i) * width;
            result.extend_from_slice(&self.data[start..start + width]);
        }
        result
    }

    /// Stores `value` at position `(rl, i, rr)`.
    ///
    /// # Panics
    ///
    /// Debug builds panic when any index is out of range; release builds panic
    /// only if the flat offset falls outside `data`.
    pub fn set(&mut self, rl: usize, i: usize, rr: usize, value: f64) {
        let idx = self.flat_index(rl, i, rr);
        self.data[idx] = value;
    }

    /// Returns the entry at position `(rl, i, rr)`.
    ///
    /// # Panics
    ///
    /// Debug builds panic when any index is out of range; release builds panic
    /// only if the flat offset falls outside `data`.
    pub fn get(&self, rl: usize, i: usize, rr: usize) -> f64 {
        self.data[self.flat_index(rl, i, rr)]
    }

    /// Maps `(rl, i, rr)` to its offset in `data`.
    ///
    /// The per-axis check exists because an out-of-range `i` or `rr` can still
    /// produce an in-bounds offset that addresses a different entry.
    fn flat_index(&self, rl: usize, i: usize, rr: usize) -> usize {
        debug_assert!(
            rl < self.rank_left && i < self.mode_size && rr < self.rank_right,
            "index ({}, {}, {}) out of range for core of shape ({}, {}, {})",
            rl,
            i,
            rr,
            self.rank_left,
            self.mode_size,
            self.rank_right
        );
        rl * self.mode_size * self.rank_right + i * self.rank_right + rr
    }
}
95
96#[derive(Debug, Clone)]
98pub struct TensorTrain {
99 pub cores: Vec<TTCore>,
101 pub shape: Vec<usize>,
103 pub ranks: Vec<usize>,
105}
106
107impl TensorTrain {
108 pub fn from_cores(cores: Vec<TTCore>) -> Self {
110 let shape: Vec<usize> = cores.iter().map(|c| c.mode_size).collect();
111 let mut ranks = vec![1];
112 for core in &cores {
113 ranks.push(core.rank_right);
114 }
115
116 Self { cores, shape, ranks }
117 }
118
119 pub fn from_vectors(vectors: Vec<Vec<f64>>) -> Self {
121 let cores: Vec<TTCore> = vectors
122 .into_iter()
123 .map(|v| {
124 let n = v.len();
125 TTCore::new(v, 1, n, 1)
126 })
127 .collect();
128
129 Self::from_cores(cores)
130 }
131
132 pub fn order(&self) -> usize {
134 self.shape.len()
135 }
136
137 pub fn max_rank(&self) -> usize {
139 self.ranks.iter().cloned().max().unwrap_or(1)
140 }
141
142 pub fn storage(&self) -> usize {
144 self.cores.iter().map(|c| c.data.len()).sum()
145 }
146
147 pub fn eval(&self, indices: &[usize]) -> f64 {
149 assert_eq!(indices.len(), self.order());
150
151 let mut result = vec![1.0];
153 let mut current_size = 1;
154
155 for (k, &idx) in indices.iter().enumerate() {
156 let core = &self.cores[k];
157 let new_size = core.rank_right;
158 let mut new_result = vec![0.0; new_size];
159
160 for rr in 0..new_size {
162 for rl in 0..current_size {
163 new_result[rr] += result[rl] * core.get(rl, idx, rr);
164 }
165 }
166
167 result = new_result;
168 current_size = new_size;
169 }
170
171 result[0]
172 }
173
174 pub fn to_dense(&self) -> DenseTensor {
176 let total_size: usize = self.shape.iter().product();
177 let mut data = vec![0.0; total_size];
178
179 let mut indices = vec![0usize; self.order()];
181 for flat_idx in 0..total_size {
182 data[flat_idx] = self.eval(&indices);
183
184 for k in (0..self.order()).rev() {
186 indices[k] += 1;
187 if indices[k] < self.shape[k] {
188 break;
189 }
190 indices[k] = 0;
191 }
192 }
193
194 DenseTensor::new(data, self.shape.clone())
195 }
196
197 pub fn dot(&self, other: &TensorTrain) -> f64 {
199 assert_eq!(self.shape, other.shape);
200
201 let mut z = vec![1.0]; let mut z_rows = 1;
205 let mut z_cols = 1;
206
207 for k in 0..self.order() {
208 let c1 = &self.cores[k];
209 let c2 = &other.cores[k];
210 let n = c1.mode_size;
211
212 let new_rows = c1.rank_right;
213 let new_cols = c2.rank_right;
214 let mut new_z = vec![0.0; new_rows * new_cols];
215
216 for i in 0..n {
218 for r1l in 0..z_rows {
219 for r2l in 0..z_cols {
220 let z_val = z[r1l * z_cols + r2l];
221
222 for r1r in 0..c1.rank_right {
223 for r2r in 0..c2.rank_right {
224 new_z[r1r * new_cols + r2r] +=
225 z_val * c1.get(r1l, i, r1r) * c2.get(r2l, i, r2r);
226 }
227 }
228 }
229 }
230 }
231
232 z = new_z;
233 z_rows = new_rows;
234 z_cols = new_cols;
235 }
236
237 z[0]
238 }
239
240 pub fn frobenius_norm(&self) -> f64 {
242 self.dot(self).sqrt()
243 }
244
245 pub fn add(&self, other: &TensorTrain) -> TensorTrain {
247 assert_eq!(self.shape, other.shape);
248
249 let mut new_cores = Vec::new();
250
251 for k in 0..self.order() {
252 let c1 = &self.cores[k];
253 let c2 = &other.cores[k];
254
255 let new_rl = if k == 0 { 1 } else { c1.rank_left + c2.rank_left };
256 let new_rr = if k == self.order() - 1 { 1 } else { c1.rank_right + c2.rank_right };
257 let n = c1.mode_size;
258
259 let mut new_data = vec![0.0; new_rl * n * new_rr];
260 let mut new_core = TTCore::new(new_data.clone(), new_rl, n, new_rr);
261
262 for i in 0..n {
263 if k == 0 {
264 for rr1 in 0..c1.rank_right {
266 new_core.set(0, i, rr1, c1.get(0, i, rr1));
267 }
268 for rr2 in 0..c2.rank_right {
269 new_core.set(0, i, c1.rank_right + rr2, c2.get(0, i, rr2));
270 }
271 } else if k == self.order() - 1 {
272 for rl1 in 0..c1.rank_left {
274 new_core.set(rl1, i, 0, c1.get(rl1, i, 0));
275 }
276 for rl2 in 0..c2.rank_left {
277 new_core.set(c1.rank_left + rl2, i, 0, c2.get(rl2, i, 0));
278 }
279 } else {
280 for rl1 in 0..c1.rank_left {
282 for rr1 in 0..c1.rank_right {
283 new_core.set(rl1, i, rr1, c1.get(rl1, i, rr1));
284 }
285 }
286 for rl2 in 0..c2.rank_left {
287 for rr2 in 0..c2.rank_right {
288 new_core.set(c1.rank_left + rl2, i, c1.rank_right + rr2, c2.get(rl2, i, rr2));
289 }
290 }
291 }
292 }
293
294 new_cores.push(new_core);
295 }
296
297 TensorTrain::from_cores(new_cores)
298 }
299
300 pub fn scale(&self, alpha: f64) -> TensorTrain {
302 let mut new_cores = self.cores.clone();
303
304 for val in new_cores[0].data.iter_mut() {
306 *val *= alpha;
307 }
308
309 TensorTrain::from_cores(new_cores)
310 }
311
312 pub fn from_dense(tensor: &DenseTensor, config: &TensorTrainConfig) -> Self {
314 let d = tensor.order();
315 if d == 0 {
316 return TensorTrain::from_cores(vec![]);
317 }
318
319 let mut cores = Vec::new();
320 let mut c = tensor.data.clone();
321 let mut remaining_shape = tensor.shape.clone();
322 let mut left_rank = 1usize;
323
324 for k in 0..d - 1 {
325 let n_k = remaining_shape[0];
326 let rest_size: usize = remaining_shape[1..].iter().product();
327
328 let rows = left_rank * n_k;
330 let cols = rest_size;
331
332 let (u, s, vt, new_rank) = simple_svd(&c, rows, cols, config);
334
335 let core = TTCore::new(u, left_rank, n_k, new_rank);
337 cores.push(core);
338
339 c = Vec::with_capacity(new_rank * cols);
341 for i in 0..new_rank {
342 for j in 0..cols {
343 c.push(s[i] * vt[i * cols + j]);
344 }
345 }
346
347 left_rank = new_rank;
348 remaining_shape.remove(0);
349 }
350
351 let n_d = remaining_shape[0];
353 let last_core = TTCore::new(c, left_rank, n_d, 1);
354 cores.push(last_core);
355
356 TensorTrain::from_cores(cores)
357 }
358}
359
360fn simple_svd(a: &[f64], rows: usize, cols: usize, config: &TensorTrainConfig) -> (Vec<f64>, Vec<f64>, Vec<f64>, usize) {
363 let max_rank = if config.max_rank > 0 {
364 config.max_rank.min(rows).min(cols)
365 } else {
366 rows.min(cols)
367 };
368
369 let mut u = Vec::new();
370 let mut s = Vec::new();
371 let mut vt = Vec::new();
372
373 let mut a_residual = a.to_vec();
374
375 for _ in 0..max_rank {
376 let (sigma, u_vec, v_vec) = power_iteration(&a_residual, rows, cols, 20);
378
379 if sigma < config.tolerance {
380 break;
381 }
382
383 s.push(sigma);
384 u.extend(u_vec.iter());
385 vt.extend(v_vec.iter());
386
387 for i in 0..rows {
389 for j in 0..cols {
390 a_residual[i * cols + j] -= sigma * u_vec[i] * v_vec[j];
391 }
392 }
393 }
394
395 let rank = s.len();
396 (u, s, vt, rank.max(1))
397}
398
399fn power_iteration(a: &[f64], rows: usize, cols: usize, max_iter: usize) -> (f64, Vec<f64>, Vec<f64>) {
401 let mut v: Vec<f64> = (0..cols).map(|i| ((i * 2654435769) as f64 / 4294967296.0) * 2.0 - 1.0).collect();
403 normalize(&mut v);
404
405 let mut u = vec![0.0; rows];
406
407 for _ in 0..max_iter {
408 for i in 0..rows {
410 u[i] = 0.0;
411 for j in 0..cols {
412 u[i] += a[i * cols + j] * v[j];
413 }
414 }
415 normalize(&mut u);
416
417 for j in 0..cols {
419 v[j] = 0.0;
420 for i in 0..rows {
421 v[j] += a[i * cols + j] * u[i];
422 }
423 }
424 normalize(&mut v);
425 }
426
427 let mut av = vec![0.0; rows];
429 for i in 0..rows {
430 for j in 0..cols {
431 av[i] += a[i * cols + j] * v[j];
432 }
433 }
434 let sigma: f64 = u.iter().zip(av.iter()).map(|(ui, avi)| ui * avi).sum();
435
436 (sigma.abs(), u, v)
437}
438
/// Rescales `v` in place to unit Euclidean length.
///
/// Vectors whose length is not above `1e-15` (including empty ones) are left
/// untouched to avoid dividing by a vanishing norm.
fn normalize(v: &mut [f64]) {
    let length = v.iter().map(|&x| x * x).sum::<f64>().sqrt();
    if !(length > 1e-15) {
        return;
    }
    v.iter_mut().for_each(|x| *x /= length);
}
447
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` is within `1e-10` of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-10
    }

    /// The rank-1 train built from the vectors `[1, 2]` and `[3, 4]`.
    fn outer_product_train() -> TensorTrain {
        TensorTrain::from_vectors(vec![vec![1.0, 2.0], vec![3.0, 4.0]])
    }

    #[test]
    fn test_tt_eval() {
        let tt = outer_product_train();

        // Entry (i, j) of an outer product is v1[i] * v2[j].
        assert!(close(tt.eval(&[0, 0]), 3.0));
        assert!(close(tt.eval(&[0, 1]), 4.0));
        assert!(close(tt.eval(&[1, 0]), 6.0));
        assert!(close(tt.eval(&[1, 1]), 8.0));
    }

    #[test]
    fn test_tt_dot() {
        let tt = outer_product_train();

        // 3^2 + 4^2 + 6^2 + 8^2 = 125
        let norm_sq = tt.dot(&tt);
        assert!(close(norm_sq, 125.0));
    }

    #[test]
    fn test_tt_from_dense() {
        let tensor = DenseTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let tt = TensorTrain::from_dense(&tensor, &TensorTrainConfig::default());

        let reconstructed = tt.to_dense();

        // Frobenius norm of the reconstruction error.
        let mut squared_error = 0.0;
        for (a, b) in tensor.data.iter().zip(reconstructed.data.iter()) {
            squared_error += (a - b).powi(2);
        }
        let error = squared_error.sqrt();

        assert!(error < 1e-6);
    }

    #[test]
    fn test_tt_add() {
        let tt1 = outer_product_train();
        let tt2 = outer_product_train();

        let sum = tt1.add(&tt2);

        // Adding a tensor to itself doubles every entry.
        assert!(close(sum.eval(&[0, 0]), 6.0));
        assert!(close(sum.eval(&[1, 1]), 16.0));
    }
}