1use crate::error::{PgmError, Result};
36use crate::graph::FactorGraph;
37use crate::linear_chain_crf::LinearChainCRF;
38use quantrs2_sim::Complex64;
39use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayD};
40use serde::{Deserialize, Serialize};
41
42#[derive(Debug, Clone)]
44pub struct Tensor {
45 pub name: String,
47 pub data: ArrayD<Complex64>,
49 pub indices: Vec<String>,
51 pub bond_dims: Vec<usize>,
53}
54
55impl Tensor {
56 pub fn new(name: String, data: ArrayD<Complex64>, indices: Vec<String>) -> Self {
58 let bond_dims = data.shape().to_vec();
59 Self {
60 name,
61 data,
62 indices,
63 bond_dims,
64 }
65 }
66
67 pub fn from_real(name: String, data: ArrayD<f64>, indices: Vec<String>) -> Self {
69 let complex_data = data.mapv(|x| Complex64::new(x, 0.0));
70 Self::new(name, complex_data, indices)
71 }
72
73 pub fn rank(&self) -> usize {
75 self.indices.len()
76 }
77
78 pub fn bond_dim(&self, index: &str) -> Option<usize> {
80 self.indices
81 .iter()
82 .position(|i| i == index)
83 .map(|pos| self.bond_dims[pos])
84 }
85
86 pub fn contract(&self, other: &Tensor) -> Result<Tensor> {
88 let shared: Vec<(usize, usize)> = self
90 .indices
91 .iter()
92 .enumerate()
93 .filter_map(|(i, idx)| {
94 other
95 .indices
96 .iter()
97 .position(|oidx| oidx == idx)
98 .map(|j| (i, j))
99 })
100 .collect();
101
102 if shared.is_empty() {
103 return self.outer_product(other);
105 }
106
107 let result_indices: Vec<String> = self
109 .indices
110 .iter()
111 .enumerate()
112 .filter(|(i, _)| !shared.iter().any(|(si, _)| si == i))
113 .map(|(_, idx)| idx.clone())
114 .chain(
115 other
116 .indices
117 .iter()
118 .enumerate()
119 .filter(|(j, _)| !shared.iter().any(|(_, sj)| sj == j))
120 .map(|(_, idx)| idx.clone()),
121 )
122 .collect();
123
124 let result_shape: Vec<usize> = self
126 .bond_dims
127 .iter()
128 .enumerate()
129 .filter(|(i, _)| !shared.iter().any(|(si, _)| si == i))
130 .map(|(_, &d)| d)
131 .chain(
132 other
133 .bond_dims
134 .iter()
135 .enumerate()
136 .filter(|(j, _)| !shared.iter().any(|(_, sj)| sj == j))
137 .map(|(_, &d)| d),
138 )
139 .collect();
140
141 let result_data = self.contract_data(other, &shared, &result_shape)?;
144
145 Ok(Tensor {
146 name: format!("{}*{}", self.name, other.name),
147 data: result_data,
148 indices: result_indices,
149 bond_dims: result_shape,
150 })
151 }
152
153 fn contract_data(
155 &self,
156 _other: &Tensor,
157 _shared: &[(usize, usize)],
158 result_shape: &[usize],
159 ) -> Result<ArrayD<Complex64>> {
160 let total_size: usize = result_shape.iter().product();
163 let data = vec![Complex64::new(0.0, 0.0); total_size.max(1)];
164 ArrayD::from_shape_vec(result_shape.to_vec(), data)
165 .map_err(|e| PgmError::InvalidDistribution(format!("Contraction failed: {}", e)))
166 }
167
168 fn outer_product(&self, other: &Tensor) -> Result<Tensor> {
170 let result_indices: Vec<String> = self
171 .indices
172 .iter()
173 .chain(other.indices.iter())
174 .cloned()
175 .collect();
176
177 let result_shape: Vec<usize> = self
178 .bond_dims
179 .iter()
180 .chain(other.bond_dims.iter())
181 .copied()
182 .collect();
183
184 let total_size: usize = result_shape.iter().product();
185 let mut data = vec![Complex64::new(0.0, 0.0); total_size.max(1)];
186
187 for (i, &a) in self.data.iter().enumerate() {
189 for (j, &b) in other.data.iter().enumerate() {
190 data[i * other.data.len() + j] = a * b;
191 }
192 }
193
194 Ok(Tensor {
195 name: format!("{}⊗{}", self.name, other.name),
196 data: ArrayD::from_shape_vec(result_shape.clone(), data).map_err(|e| {
197 PgmError::InvalidDistribution(format!("Outer product failed: {}", e))
198 })?,
199 indices: result_indices,
200 bond_dims: result_shape,
201 })
202 }
203}
204
205#[derive(Debug, Clone)]
207pub struct TensorNetwork {
208 tensors: Vec<Tensor>,
210 physical_indices: Vec<String>,
212 bond_indices: Vec<String>,
214}
215
216impl TensorNetwork {
217 pub fn new() -> Self {
219 Self {
220 tensors: Vec::new(),
221 physical_indices: Vec::new(),
222 bond_indices: Vec::new(),
223 }
224 }
225
226 pub fn add_tensor(&mut self, tensor: Tensor) {
228 self.tensors.push(tensor);
229 }
230
231 pub fn add_physical_index(&mut self, index: String) {
233 if !self.physical_indices.contains(&index) {
234 self.physical_indices.push(index);
235 }
236 }
237
238 pub fn add_bond_index(&mut self, index: String) {
240 if !self.bond_indices.contains(&index) {
241 self.bond_indices.push(index);
242 }
243 }
244
245 pub fn num_tensors(&self) -> usize {
247 self.tensors.len()
248 }
249
250 pub fn num_physical_indices(&self) -> usize {
252 self.physical_indices.len()
253 }
254
255 pub fn total_bond_dim(&self) -> usize {
257 self.tensors
258 .iter()
259 .map(|t| t.bond_dims.iter().product::<usize>())
260 .sum()
261 }
262
263 pub fn contract(&self) -> Result<Tensor> {
267 if self.tensors.is_empty() {
268 return Err(PgmError::InvalidGraph(
269 "Cannot contract empty tensor network".to_string(),
270 ));
271 }
272
273 let mut result = self.tensors[0].clone();
274 for tensor in self.tensors.iter().skip(1) {
275 result = result.contract(tensor)?;
276 }
277
278 Ok(result)
279 }
280
281 pub fn partition_function(&self) -> Result<Complex64> {
283 let contracted = self.contract()?;
284 Ok(contracted.data.iter().sum())
285 }
286
287 pub fn marginal(&self, indices: &[String]) -> Result<Tensor> {
289 let contracted = self.contract()?;
291
292 let keep_positions: Vec<usize> = contracted
294 .indices
295 .iter()
296 .enumerate()
297 .filter_map(
298 |(i, idx)| {
299 if indices.contains(idx) {
300 Some(i)
301 } else {
302 None
303 }
304 },
305 )
306 .collect();
307
308 if keep_positions.is_empty() {
309 let sum: Complex64 = contracted.data.iter().sum();
311 return Ok(Tensor::new(
312 "marginal".to_string(),
313 ArrayD::from_elem(vec![], sum),
314 vec![],
315 ));
316 }
317
318 let result_shape: Vec<usize> = keep_positions
320 .iter()
321 .map(|&i| contracted.bond_dims[i])
322 .collect();
323 let result_indices: Vec<String> = keep_positions
324 .iter()
325 .map(|&i| contracted.indices[i].clone())
326 .collect();
327
328 Ok(Tensor {
330 name: "marginal".to_string(),
331 data: contracted.data, indices: result_indices,
333 bond_dims: result_shape,
334 })
335 }
336}
337
338impl Default for TensorNetwork {
339 fn default() -> Self {
340 Self::new()
341 }
342}
343
344pub fn factor_graph_to_tensor_network(graph: &FactorGraph) -> Result<TensorNetwork> {
348 let mut tn = TensorNetwork::new();
349
350 for var_name in graph.variable_names() {
352 tn.add_physical_index(var_name.clone());
353 }
354
355 for factor in graph.factors() {
357 let indices = factor.variables.clone();
358 let tensor = Tensor::from_real(factor.name.clone(), factor.values.clone(), indices);
359 tn.add_tensor(tensor);
360 }
361
362 Ok(tn)
363}
364
365#[derive(Debug, Clone)]
371pub struct MatrixProductState {
372 pub tensors: Vec<Array3<Complex64>>,
374 pub physical_dims: Vec<usize>,
376 pub bond_dims: Vec<usize>,
378}
379
380impl MatrixProductState {
381 pub fn new(length: usize, physical_dim: usize, bond_dim: usize) -> Self {
383 let mut tensors = Vec::with_capacity(length);
384 let mut bond_dims = Vec::with_capacity(length + 1);
385
386 bond_dims.push(1); for i in 0..length {
389 let left_dim = bond_dims[i];
390 let right_dim = if i == length - 1 { 1 } else { bond_dim };
391 bond_dims.push(right_dim);
392
393 let tensor = Array3::from_shape_fn((left_dim, physical_dim, right_dim), |_| {
395 Complex64::new(1.0 / (left_dim * physical_dim * right_dim) as f64, 0.0)
396 });
397 tensors.push(tensor);
398 }
399
400 Self {
401 tensors,
402 physical_dims: vec![physical_dim; length],
403 bond_dims,
404 }
405 }
406
407 pub fn product_state(length: usize, physical_dim: usize) -> Self {
409 let mut tensors = Vec::with_capacity(length);
410
411 for _ in 0..length {
412 let mut tensor = Array3::zeros((1, physical_dim, 1));
414 tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
415 tensors.push(tensor);
416 }
417
418 Self {
419 tensors,
420 physical_dims: vec![physical_dim; length],
421 bond_dims: vec![1; length + 1],
422 }
423 }
424
425 pub fn length(&self) -> usize {
427 self.tensors.len()
428 }
429
430 pub fn max_bond_dim(&self) -> usize {
432 *self.bond_dims.iter().max().unwrap_or(&1)
433 }
434
435 pub fn to_state_vector(&self) -> Result<Array1<Complex64>> {
439 if self.tensors.is_empty() {
440 return Ok(Array1::from(vec![Complex64::new(1.0, 0.0)]));
441 }
442
443 let total_dim: usize = self.physical_dims.iter().product();
444 let mut state = Array1::zeros(total_dim);
445
446 for basis_idx in 0..total_dim {
448 let mut indices = vec![0; self.tensors.len()];
449 let mut temp = basis_idx;
450 for (i, &dim) in self.physical_dims.iter().enumerate().rev() {
451 indices[i] = temp % dim;
452 temp /= dim;
453 }
454
455 let mut amplitude = Complex64::new(1.0, 0.0);
457 let mut left_idx = 0;
458
459 for (site, &phys_idx) in indices.iter().enumerate() {
460 let tensor = &self.tensors[site];
461 let right_dim = tensor.shape()[2];
463 let mut sum = Complex64::new(0.0, 0.0);
464 for right_idx in 0..right_dim {
465 sum += tensor[[left_idx, phys_idx, right_idx]];
466 }
467 amplitude *= sum;
468 left_idx = 0; }
470
471 state[basis_idx] = amplitude;
472 }
473
474 let norm: f64 = state
476 .iter()
477 .map(|x: &Complex64| x.norm_sqr())
478 .sum::<f64>()
479 .sqrt();
480 if norm > 1e-10 {
481 for x in state.iter_mut() {
482 *x /= norm;
483 }
484 }
485
486 Ok(state)
487 }
488
489 pub fn norm(&self) -> f64 {
491 let state_result: Result<Array1<Complex64>> = self.to_state_vector();
492 match state_result {
493 Ok(state) => {
494 let state_arr: Array1<Complex64> = state;
495 state_arr
496 .iter()
497 .map(|x: &Complex64| x.norm_sqr())
498 .sum::<f64>()
499 .sqrt()
500 }
501 Err(_) => 0.0,
502 }
503 }
504
505 pub fn expectation_local(
507 &self,
508 site: usize,
509 operator: &Array2<Complex64>,
510 ) -> Result<Complex64> {
511 if site >= self.tensors.len() {
512 return Err(PgmError::VariableNotFound(format!(
513 "Site {} out of range",
514 site
515 )));
516 }
517
518 let state = self.to_state_vector()?;
520
521 let mut result = Complex64::new(0.0, 0.0);
522 let num_sites = self.tensors.len();
523 let total_dim: usize = self.physical_dims.iter().product();
524
525 for basis_idx in 0..total_dim {
526 let mut indices = vec![0; num_sites];
528 let mut temp = basis_idx;
529 for (i, &dim) in self.physical_dims.iter().enumerate().rev() {
530 indices[i] = temp % dim;
531 temp /= dim;
532 }
533
534 for new_idx in 0..self.physical_dims[site] {
536 let op_elem = operator[[new_idx, indices[site]]];
537 if op_elem.norm_sqr() > 1e-20 {
538 let mut new_basis_idx = 0;
540 let mut multiplier = 1;
541 for (i, &idx) in indices.iter().enumerate().rev() {
542 let idx_to_use = if i == site { new_idx } else { idx };
543 new_basis_idx += idx_to_use * multiplier;
544 multiplier *= self.physical_dims[i];
545 }
546
547 result += state[new_basis_idx].conj() * op_elem * state[basis_idx];
548 }
549 }
550 }
551
552 Ok(result)
553 }
554}
555
556pub fn linear_chain_to_mps(
560 crf: &LinearChainCRF,
561 input_sequence: &[usize],
562) -> Result<MatrixProductState> {
563 let factor_graph = crf.to_factor_graph(input_sequence)?;
564 let num_sites = input_sequence.len();
565
566 if num_sites == 0 {
567 return Err(PgmError::InvalidGraph("Empty sequence".to_string()));
568 }
569
570 let num_states = factor_graph
572 .get_variable("y_0")
573 .map(|v| v.cardinality)
574 .unwrap_or(2);
575
576 let mut mps = MatrixProductState::new(num_sites, num_states, num_states);
578
579 for t in 0..num_sites {
581 let emission_name = format!("emission_{}", t);
582 let transition_name = format!("transition_{}", t);
583
584 if let Some(emission) = factor_graph.get_factor_by_name(&emission_name) {
586 for (s, &val) in emission.values.iter().enumerate() {
588 if s < num_states {
589 mps.tensors[t][[0, s, 0]] = Complex64::new(val.sqrt(), 0.0);
590 }
591 }
592 }
593
594 if t > 0 {
596 if let Some(transition) = factor_graph.get_factor_by_name(&transition_name) {
597 for s_prev in 0..num_states {
599 for s_curr in 0..num_states {
600 if s_prev < transition.values.shape()[0]
601 && s_curr < transition.values.shape()[1]
602 {
603 let val = transition.values[[s_prev, s_curr]];
604 let tensor = &mut mps.tensors[t];
606 let left_dim = tensor.shape()[0];
607 let right_dim = tensor.shape()[2];
608 if s_prev < left_dim && s_curr < num_states {
609 tensor[[s_prev.min(left_dim - 1), s_curr, 0.min(right_dim - 1)]] =
610 Complex64::new(val.sqrt(), 0.0);
611 }
612 }
613 }
614 }
615 }
616 }
617 }
618
619 Ok(mps)
620}
621
622#[derive(Debug, Clone, Serialize, Deserialize)]
624pub struct TensorNetworkStats {
625 pub num_tensors: usize,
627 pub total_elements: usize,
629 pub max_rank: usize,
631 pub avg_rank: f64,
633 pub num_physical_indices: usize,
635 pub num_bond_indices: usize,
637}
638
639impl TensorNetwork {
640 pub fn stats(&self) -> TensorNetworkStats {
642 let num_tensors = self.tensors.len();
643 let total_elements: usize = self.tensors.iter().map(|t| t.data.len()).sum();
644 let max_rank = self.tensors.iter().map(|t| t.rank()).max().unwrap_or(0);
645 let avg_rank = if num_tensors > 0 {
646 self.tensors.iter().map(|t| t.rank()).sum::<usize>() as f64 / num_tensors as f64
647 } else {
648 0.0
649 };
650
651 TensorNetworkStats {
652 num_tensors,
653 total_elements,
654 max_rank,
655 avg_rank,
656 num_physical_indices: self.physical_indices.len(),
657 num_bond_indices: self.bond_indices.len(),
658 }
659 }
660}
661
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::FactorGraph;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_tensor_creation() {
        let data = ArrayD::from_shape_vec(vec![2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("Array creation failed");
        let tensor = Tensor::from_real(
            "test".to_string(),
            data,
            vec!["i".to_string(), "j".to_string()],
        );

        assert_eq!(tensor.rank(), 2);
        assert_eq!(tensor.bond_dim("i"), Some(2));
        assert_eq!(tensor.bond_dim("j"), Some(3));
        assert_eq!(tensor.bond_dim("k"), None);
    }

    #[test]
    fn test_tensor_network_creation() {
        let mut tn = TensorNetwork::new();
        let data = ArrayD::from_shape_vec(vec![2], vec![1.0, 0.0]).expect("Array creation failed");
        let tensor = Tensor::from_real("A".to_string(), data, vec!["x".to_string()]);

        tn.add_tensor(tensor);
        tn.add_physical_index("x".to_string());
        // Re-registering the same label must not create a duplicate.
        tn.add_physical_index("x".to_string());

        assert_eq!(tn.num_tensors(), 1);
        assert_eq!(tn.num_physical_indices(), 1);
    }

    #[test]
    fn test_empty_network_contract_fails() {
        let tn = TensorNetwork::new();
        assert!(tn.contract().is_err());
        assert!(tn.partition_function().is_err());
    }

    #[test]
    fn test_single_tensor_partition_function() {
        let mut tn = TensorNetwork::new();
        let data =
            ArrayD::from_shape_vec(vec![2], vec![0.25, 0.75]).expect("Array creation failed");
        tn.add_tensor(Tensor::from_real("A".to_string(), data, vec!["x".to_string()]));

        let z = tn.partition_function().expect("Partition function failed");
        assert_abs_diff_eq!(z.re, 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(z.im, 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_factor_graph_to_tn() {
        let mut graph = FactorGraph::new();
        graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
        graph.add_variable_with_card("y".to_string(), "Binary".to_string(), 2);

        let tn = factor_graph_to_tensor_network(&graph);
        assert!(tn.is_ok());

        let tn = tn.expect("TN creation failed");
        assert_eq!(tn.num_physical_indices(), 2);
    }

    #[test]
    fn test_mps_creation() {
        let mps = MatrixProductState::new(4, 2, 4);

        assert_eq!(mps.length(), 4);
        assert_eq!(mps.physical_dims.len(), 4);
        assert_eq!(mps.bond_dims, vec![1, 4, 4, 4, 1]);
        assert!(mps.max_bond_dim() <= 4);
    }

    #[test]
    fn test_mps_product_state() {
        let mps = MatrixProductState::product_state(3, 2);

        assert_eq!(mps.length(), 3);
        assert_eq!(mps.max_bond_dim(), 1);

        let norm = mps.norm();
        assert_abs_diff_eq!(norm, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_mps_to_state_vector() {
        let mps = MatrixProductState::product_state(2, 2);
        let state = mps.to_state_vector();

        assert!(state.is_ok());
        let state = state.expect("State vector failed");
        assert_eq!(state.len(), 4);

        // The all-zeros product state is the first basis vector.
        assert_abs_diff_eq!(state[0].re, 1.0, epsilon = 1e-12);
        for amp in state.iter().skip(1) {
            assert_abs_diff_eq!(amp.norm(), 0.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_tensor_network_stats() {
        let mut tn = TensorNetwork::new();
        let data1 =
            ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).expect("Array creation failed");
        let data2 =
            ArrayD::from_shape_vec(vec![3, 4], vec![1.0; 12]).expect("Array creation failed");

        tn.add_tensor(Tensor::from_real(
            "A".to_string(),
            data1,
            vec!["i".to_string(), "j".to_string()],
        ));
        tn.add_tensor(Tensor::from_real(
            "B".to_string(),
            data2,
            vec!["j".to_string(), "k".to_string()],
        ));

        let stats = tn.stats();
        assert_eq!(stats.num_tensors, 2);
        assert_eq!(stats.total_elements, 18);
        assert_eq!(stats.max_rank, 2);
        assert_abs_diff_eq!(stats.avg_rank, 2.0, epsilon = 1e-12);
        assert_eq!(stats.num_physical_indices, 0);
        assert_eq!(stats.num_bond_indices, 0);
    }

    #[test]
    fn test_tensor_outer_product() {
        let data1 = ArrayD::from_shape_vec(vec![2], vec![1.0, 2.0]).expect("Array creation failed");
        let data2 =
            ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).expect("Array creation failed");

        let t1 = Tensor::from_real("A".to_string(), data1, vec!["i".to_string()]);
        let t2 = Tensor::from_real("B".to_string(), data2, vec!["j".to_string()]);

        let result = t1.contract(&t2);
        assert!(result.is_ok());

        let result = result.expect("Contraction failed");
        assert_eq!(result.indices.len(), 2);
        assert_eq!(result.bond_dims, vec![2, 3]);

        // Row-major outer product: result[i][j] = a[i] * b[j].
        let values: Vec<f64> = result.data.iter().map(|c| c.re).collect();
        assert_eq!(values, vec![1.0, 2.0, 3.0, 2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_mps_expectation() {
        let mps = MatrixProductState::product_state(2, 2);

        // Pauli Z operator.
        let z_op = Array2::from_shape_vec(
            (2, 2),
            vec![
                Complex64::new(1.0, 0.0),
                Complex64::new(0.0, 0.0),
                Complex64::new(0.0, 0.0),
                Complex64::new(-1.0, 0.0),
            ],
        )
        .expect("Operator creation failed");

        let exp_val = mps.expectation_local(0, &z_op);
        assert!(exp_val.is_ok());

        let exp_val = exp_val.expect("Expectation failed");
        // |0> is the +1 eigenstate of Z.
        assert_abs_diff_eq!(exp_val.re, 1.0, epsilon = 1e-6);

        let exp_val_last = mps
            .expectation_local(1, &z_op)
            .expect("Expectation failed");
        assert_abs_diff_eq!(exp_val_last.re, 1.0, epsilon = 1e-6);

        assert!(mps.expectation_local(2, &z_op).is_err());
    }
}